diff --git a/README.md b/README.md index c2718f9f..ef8a6b8c 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ Documentation can be found on https://docs.aleph.im/tools/aleph-client/ Some cryptographic functionalities use curve secp256k1 and require installing [libsecp256k1](https://github.com/bitcoin-core/secp256k1). -> apt-get install -y python3-pip libsecp256k1-dev +> apt-get install -y python3-pip libsecp256k1-dev squashfs-tools ### macOs @@ -24,8 +24,7 @@ installing [libsecp256k1](https://github.com/bitcoin-core/secp256k1). ### Windows -The software is not tested on Windows, but should work using -the Windows Subsystem for Linux (WSL). +We recommend using [WSL](https://learn.microsoft.com/en-us/windows/wsl/install) (Windows Subsystem for Linux). ## Installation @@ -85,28 +84,15 @@ To install from source and still be able to modify the source code: ## Updating the User Documentation -The user documentation for Aleph is maintained in the `aleph-docs` repository. When releasing a new version, it's -important to update the documentation as part of the release process. - -### Steps for Updating Documentation - -Documentation is generated using the `typer` command. +The user documentation for Aleph is maintained in the [aleph-docs](https://github.com/aleph-im/aleph-docs) repository. The CLI page is generated using the `typer` command. When releasing a new version, it's important to update the documentation as part of the release process. If you have the `aleph-docs` repository cloned as a sibling folder to your current directory, you can use the following command to generate updated documentation: ```shell -./scripts/gendoc.py src/aleph_client/__main__.py docs --name aleph --title 'Aleph CLI Documentation' --output ../aleph-docs/docs/tools/aleph-client/usage.md +./scripts/gendoc.py src/aleph_client/__main__.py docs \ + --name aleph --title 'Aleph CLI Documentation' \ + --output ../aleph-docs/docs/tools/aleph-client/usage.md ``` -After generating the documentation, you may need to update the path for the private key, as this depends on the user -configuration. This can be fixed manually using the `sed` command. For example: - -```shell -sed -i 's#/home/olivier/.aleph-im/private-keys/sol2.key#~/.aleph-im/private-keys/ethereum.key#' ../aleph-docs/docs/tools/aleph-client/usage.md -``` - -This command replaces any hardcoded private key paths with the correct configuration path ( -`~/.aleph-im/private-keys/ethereum.key`). - -Once the documentation is updated, open a Pull Request (PR) on the `aleph-docs` repository with your changes. \ No newline at end of file +Then, open a Pull Request (PR) on the [aleph-docs](https://github.com/aleph-im/aleph-docs/pulls) repository with your changes. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a9692775..dd6ecbc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,6 +208,7 @@ pythonpath = [ testpaths = [ "tests", ] +asyncio_default_fixture_loop_scope = "function" [tool.coverage.run] branch = true diff --git a/scripts/gendoc.py b/scripts/gendoc.py index 58b5f6d9..916f3a8d 100755 --- a/scripts/gendoc.py +++ b/scripts/gendoc.py @@ -3,15 +3,17 @@ Copied from typer.cli.py to customise doc generation """ -import click import importlib.util +import os import re import sys +from pathlib import Path +from typing import Any, List, Optional + +import click import typer import typer.core from click import Command, Group -from pathlib import Path -from typing import Any, List, Optional default_app_names = ("app", "cli", "main") default_func_names = ("main", "cli", "app") @@ -246,6 +248,19 @@ def get_docs_for_click( return docs +def replace_local_values(text: str) -> str: + # Replace username + current_user = Path.home().owner() + text = text.replace(current_user, "$USER") + + # Replace private key file path + pattern = r"[^/]+\.key" + replacement = r"ethereum.key" + text = re.sub(pattern, replacement, text) + + return text + + @utils_app.command() def docs( ctx: typer.Context, @@ -269,13 +284,14 @@ def docs( typer.echo("No Typer app found", err=True) raise typer.Abort() click_obj = typer.main.get_command(typer_obj) - docs = get_docs_for_click(obj=click_obj, ctx=ctx, name=name, title=title) - clean_docs = f"{docs.strip()}\n" + generated_docs = get_docs_for_click(obj=click_obj, ctx=ctx, name=name, title=title) + clean_docs = f"{generated_docs.strip()}\n" + fixed_docs = replace_local_values(clean_docs) if output: - output.write_text(clean_docs) + output.write_text(fixed_docs) typer.echo(f"Docs saved to: {output}") else: - typer.echo(clean_docs) + typer.echo(fixed_docs) utils_app() diff --git a/src/aleph_client/__main__.py b/src/aleph_client/__main__.py index e0d5b685..c082ab69 100644 --- a/src/aleph_client/__main__.py +++ b/src/aleph_client/__main__.py @@ -17,20 +17,19 @@ app = AsyncTyper(no_args_is_help=True) -app.add_typer(account.app, name="account", help="Manage account") -app.add_typer(aggregate.app, name="aggregate", help="Manage aggregate messages on aleph.im") -app.add_typer(files.app, name="file", help="File uploading and pinning on IPFS and aleph.im") +app.add_typer(account.app, name="account", help="Manage accounts") app.add_typer( message.app, name="message", - help="Post, amend, watch and forget messages on aleph.im", + help="Manage messages (post, amend, watch and forget) on aleph.im & twentysix.cloud", ) -app.add_typer(program.app, name="program", help="Upload and update programs on aleph.im VM") +app.add_typer(aggregate.app, name="aggregate", help="Manage aggregate messages on aleph.im & twentysix.cloud") +app.add_typer(files.app, name="file", help="Manage files (upload and pin on IPFS) on aleph.im & twentysix.cloud") +app.add_typer(program.app, name="program", help="Manage programs (micro-VMs) on aleph.im & twentysix.cloud") +app.add_typer(instance.app, name="instance", help="Manage instances (VMs) on aleph.im & twentysix.cloud") +app.add_typer(domain.app, name="domain", help="Manage custom domain (DNS) on aleph.im & twentysix.cloud") +app.add_typer(node.app, name="node", help="Get node info on aleph.im & twentysix.cloud") app.add_typer(about.app, name="about", help="Display the informations of Aleph CLI") -app.add_typer(node.app, name="node", help="Get node info on aleph.im network") -app.add_typer(domain.app, name="domain", help="Manage custom Domain (dns) on aleph.im") -app.add_typer(instance.app, name="instance", help="Manage instances (VMs) on aleph.im network") - if __name__ == "__main__": app() diff --git a/src/aleph_client/commands/account.py b/src/aleph_client/commands/account.py index da7f890d..570ce7aa 100644 --- a/src/aleph_client/commands/account.py +++ b/src/aleph_client/commands/account.py @@ -24,8 +24,10 @@ from aleph.sdk.utils import bytes_from_hex from aleph_message.models import Chain from rich.console import Console +from rich.panel import Panel from rich.prompt import Prompt from rich.table import Table +from rich.text import Text from typer.colors import GREEN, RED from aleph_client.commands import help_strings @@ -233,7 +235,7 @@ def sign_bytes( if not message: message = input_multiline() - coroutine = account.sign_raw(message.encode()) + coroutine = account.sign_raw(str(message).encode()) signature = asyncio.run(coroutine) typer.echo("\nSignature: " + signature.hex()) @@ -259,15 +261,44 @@ async def balance( if response.status == 200: balance_data = await response.json() balance_data["available_amount"] = balance_data["balance"] - balance_data["locked_amount"] - typer.echo( - "\n" - + f"Address: {balance_data['address']}\n" - + f"Balance: {balance_data['balance']:.2f}".rstrip("0").rstrip(".") - + "\n" - + f" - Locked: {balance_data['locked_amount']:.2f}".rstrip("0").rstrip(".") - + "\n" - + f" - Available: {balance_data['available_amount']:.2f}".rstrip("0").rstrip(".") - + "\n" + + infos = [ + Text.from_markup(f"Address: [bright_cyan]{balance_data['address']}[/bright_cyan]"), + Text.from_markup( + f"\nBalance: [bright_cyan]{balance_data['balance']:.2f}".rstrip("0").rstrip(".") + + "[/bright_cyan]" + ), + ] + details = balance_data.get("details") + if details: + infos += [Text("\n ↳ Details")] + for chain, chain_balance in details.items(): + infos += [ + Text.from_markup( + f"\n {chain}: [orange3]{chain_balance:.2f}".rstrip("0").rstrip(".") + "[/orange3]" + ) + ] + available_color = "bright_cyan" if balance_data["available_amount"] >= 0 else "red" + infos += [ + Text.from_markup( + f"\n - Locked: [bright_cyan]{balance_data['locked_amount']:.2f}".rstrip("0").rstrip(".") + + "[/bright_cyan]" + ), + Text.from_markup( + f"\n - Available: [{available_color}]{balance_data['available_amount']:.2f}".rstrip("0").rstrip( + "." + ) + + f"[/{available_color}]" + ), + ] + console.print( + Panel( + Text.assemble(*infos), + title="Account Infos", + border_style="bright_cyan", + expand=False, + title_align="left", + ) ) else: typer.echo(f"Failed to retrieve balance for address {address}. Status code: {response.status}") diff --git a/src/aleph_client/commands/files.py b/src/aleph_client/commands/files.py index 095202b8..36848b3f 100644 --- a/src/aleph_client/commands/files.py +++ b/src/aleph_client/commands/files.py @@ -11,8 +11,9 @@ from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient from aleph.sdk.account import _load_account from aleph.sdk.conf import settings -from aleph.sdk.types import AccountFromPrivateKey, StorageEnum -from aleph_message.models import ItemHash, StoreMessage +from aleph.sdk.types import AccountFromPrivateKey, StorageEnum, StoredContent +from aleph.sdk.utils import safe_getattr +from aleph_message.models import ItemHash, ItemType, MessageType, StoreMessage from aleph_message.status import MessageStatus from pydantic import BaseModel, Field from rich import box @@ -101,28 +102,42 @@ async def download( output_path: Path = typer.Option(Path("."), help="Output directory path"), file_name: str = typer.Option(None, help="Output file name (without extension)"), file_extension: str = typer.Option(None, help="Output file extension"), + only_info: bool = False, + verbose: bool = True, debug: bool = False, -): - """Download a file on aleph.im.""" +) -> Optional[StoredContent]: + """Download a file from aleph.im or display related infos.""" setup_logging(debug) - output_path.mkdir(parents=True, exist_ok=True) + if not only_info: + output_path.mkdir(parents=True, exist_ok=True) - file_name = file_name if file_name else hash - file_extension = file_extension if file_extension else "" + file_name = file_name if file_name else hash + file_extension = file_extension if file_extension else "" - output_file_path = output_path / f"{file_name}{file_extension}" + output_file_path = output_path / f"{file_name}{file_extension}" - async with AlephHttpClient(api_server=settings.API_HOST) as client: - logger.info(f"Downloading {hash} ...") - with open(output_file_path, "wb") as fd: - if not use_ipfs: - await client.download_file_to_buffer(hash, fd) - else: - await client.download_file_ipfs_to_buffer(hash, fd) + async with AlephHttpClient(api_server=settings.API_HOST) as client: + logger.info(f"Downloading {hash} ...") + with open(output_file_path, "wb") as fd: + if not use_ipfs: + await client.download_file_to_buffer(hash, fd) + else: + await client.download_file_ipfs_to_buffer(hash, fd) - logger.debug("File downloaded successfully.") + logger.debug("File downloaded successfully.") + else: + async with AlephHttpClient(api_server=settings.API_HOST) as client: + content = await client.get_stored_content(hash) + if verbose: + typer.echo( + f"Filename: {content.filename}\nHash: {content.hash}\nURL: {content.url}" + if safe_getattr(content, "url") + else safe_getattr(content, "error") + ) + return content + return None @app.command() diff --git a/src/aleph_client/commands/help_strings.py b/src/aleph_client/commands/help_strings.py index d6dfd1ee..0466ee15 100644 --- a/src/aleph_client/commands/help_strings.py +++ b/src/aleph_client/commands/help_strings.py @@ -2,22 +2,18 @@ CHANNEL = "Aleph.im network channel where the message is or will be broadcasted" PRIVATE_KEY = "Your private key. Cannot be used with --private-key-file" PRIVATE_KEY_FILE = "Path to your private key file" -REF = "Checkout https://aleph-im.gitbook.io/aleph-js/api-resources-reference/posts" +REF = "Item hash of the message to update" SIGNABLE_MESSAGE = "Message to sign" CUSTOM_DOMAIN_TARGET_TYPES = "IPFS|PROGRAM|INSTANCE" CUSTOM_DOMAIN_OWNER_ADDRESS = "Owner address, default current account" CUSTOM_DOMAIN_NAME = "Domain name. ex: aleph.im" CUSTOM_DOMAIN_ITEM_HASH = "Item hash" SKIP_VOLUME = "Skip prompt to attach more volumes" -PERSISTENT_VOLUME = """Persistent volumes are allocated on the host machine and are not deleted when the VM is stopped.\n -Requires at least "name", "persistence", "mount" and "size_mib". For more info, see the docs: https://docs.aleph.im/computing/volumes/persistent/\n -Example: --persistent_volume name=data,persistence=host,size_mib=100,mount=/opt/data""" -EPHEMERAL_VOLUME = """Ephemeral volumes are allocated on the host machine when the VM is started and deleted when the VM is stopped.\n -Requires at least "name", "mount" and "size_mib".\n -Example: --ephemeral-volume name=temp,size_mib=100,mount=/tmp/data""" -IMMUTABLE_VOLUME = """Immutable volumes are pinned on the network and can be used by multiple VMs at the same time. They are read-only and useful for setting up libraries or other dependencies.\n -Requires at least "name", "ref" (message hash) and "mount" path. "use_latest" is True by default, to use the latest version of the volume, if it has been amended. See the docs for more info: https://docs.aleph.im/computing/volumes/immutable/\n -Example: --immutable-volume name=libs,ref=25a393222692c2f73489dc6710ae87605a96742ceef7b91de4d7ec34bb688d94,mount=/lib/python3.8/site-packages""" +PERSISTENT_VOLUME = "Persistent volumes are allocated on the host machine and are not deleted when the VM is stopped.\nRequires at least `name`, `persistence`, `mount` and `size_mib`. For more info, see the docs: https://docs.aleph.im/computing/volumes/persistent/\nExample: --persistent_volume name=data,persistence=host,size_mib=100,mount=/opt/data" +EPHEMERAL_VOLUME = "Ephemeral volumes are allocated on the host machine when the VM is started and deleted when the VM is stopped.\nRequires at least `name`, `mount` and `size_mib`.\nExample: --ephemeral-volume name=temp,size_mib=100,mount=/tmp/data" +IMMUTABLE_VOLUME = "Immutable volumes are pinned on the network and can be used by multiple VMs at the same time. They are read-only and useful for setting up libraries or other dependencies.\nRequires at least `name`, `ref` (message hash) and `mount` path. `use_latest` is True by default, to use the latest version of the volume, if it has been amended. See the docs for more info: https://docs.aleph.im/computing/volumes/immutable/\nExample: --immutable-volume name=libs,ref=25a3...8d94,mount=/lib/python3.11/site-packages" +SKIP_ENV_VAR = "Skip prompt to set environment variables" +ENVIRONMENT_VARIABLES = "Environment variables to pass. They will be public and visible in the message, so don't include secrets. Must be a comma separated list. Example: `KEY=value` or `KEY=value,KEY=value`" ASK_FOR_CONFIRMATION = "Prompt user for confirmation" IPFS_CATCH_ALL_PATH = "Choose a relative path to catch all unmatched route or a 404 error" PAYMENT_TYPE = "Payment method, either holding tokens, NFTs, or Pay-As-You-Go via token streaming" @@ -46,7 +42,7 @@ VM_ID = "Item hash of your VM. If provided, skip the instance creation, else create a new one" VM_NOT_READY = "VM not initialized/started" VM_SCHEDULED = "VM scheduled but not available yet" -VM_NOT_AVAILABLE_YET = "VM not available yet" +CRN_UNKNOWN = "Unknown" CRN_PENDING = "Pending..." ALLOCATION_AUTO = "Auto - Scheduler" ALLOCATION_MANUAL = "Manual - Selection" @@ -56,3 +52,12 @@ ADDRESS_CHAIN = "Chain for the address" CREATE_REPLACE = "Overwrites private key file if it already exists" CREATE_ACTIVE = "Loads the new private key after creation" +PROMPT_CRN_URL = "URL of the CRN (Compute node) on which the instance is running" +PROMPT_PROGRAM_CRN_URL = "URL of the CRN (Compute node) on which the program is running" +PROGRAM_PATH = "Path to your source code. Can be a directory, a .squashfs file or a .zip archive" +PROGRAM_ENTRYPOINT = "Your program entrypoint. Example: `main:app` for Python programs, else `run.sh` for a script containing your launch command" +PROGRAM_RUNTIME = "Hash of the runtime to use for your program. You can also create your own runtime and pin it. Currently defaults to `{runtime_id}` (Use `aleph program runtime-checker` to inspect it)" +PROGRAM_BETA = "If true, you will be prompted to add message subscriptions to your program" +PROGRAM_UPDATABLE = "Allow program updates. By default, only the source code can be modified without requiring redeployement (same item hash). When enabled (set to True), this option allows to update any other field. However, such modifications will require a program redeployment (new item hash)" +PROGRAM_KEEP_CODE = "Keep the source code intact instead of deleting it" +PROGRAM_KEEP_PREV = "Keep the previous program intact instead of deleting it" diff --git a/src/aleph_client/commands/instance/__init__.py b/src/aleph_client/commands/instance/__init__.py index c5cecb2f..64d729b0 100644 --- a/src/aleph_client/commands/instance/__init__.py +++ b/src/aleph_client/commands/instance/__init__.py @@ -26,12 +26,10 @@ from aleph.sdk.query.filters import MessageFilter from aleph.sdk.query.responses import PriceResponse from aleph.sdk.types import StorageEnum -from aleph.sdk.utils import calculate_firmware_hash -from aleph_message.models import InstanceMessage, StoreMessage -from aleph_message.models.base import Chain, MessageType +from aleph.sdk.utils import calculate_firmware_hash, safe_getattr +from aleph_message.models import Chain, InstanceMessage, MessageType, StoreMessage from aleph_message.models.execution.base import Payment, PaymentType from aleph_message.models.execution.environment import ( - GpuDeviceClass, GpuProperties, HostRequirements, HypervisorType, @@ -42,6 +40,7 @@ from click import echo from rich import box from rich.console import Console +from rich.panel import Panel from rich.prompt import Confirm, Prompt from rich.table import Table from rich.text import Text @@ -57,8 +56,8 @@ from aleph_client.commands.node import NodeInfo, _fetch_nodes from aleph_client.commands.utils import ( filter_only_valid_messages, + find_sevctl_or_exit, get_or_prompt_volumes, - safe_getattr, setup_logging, str_to_datetime, validate_ssh_pubkey_file, @@ -66,6 +65,7 @@ validated_prompt, wait_for_confirmed_flow, wait_for_processed_instance, + yes_no_input, ) from aleph_client.models import CRNInfo from aleph_client.utils import AsyncTyper, sanitize_url @@ -76,6 +76,11 @@ # TODO: This should be put on the API to get always from there FLOW_INSTANCE_PRICE_PER_SECOND = Decimal(0.0000155) # 0.055/h +hold_chains = get_chains_with_holding() + [Chain.SOL] +super_token_chains = get_chains_with_super_token() +metavar_valid_chains = f"[{'|'.join(hold_chains)}]" +metavar_valid_payment_types = f"[{'|'.join(PaymentType)}|nft]" + @app.command() async def create( @@ -83,11 +88,14 @@ async def create( None, help=help_strings.PAYMENT_TYPE, callback=lambda pt: None if pt is None else pt.lower(), - # callback=lambda pt: None if pt is None else PaymentType.hold if pt == "nft" else PaymentType(pt), - metavar=f"[{'|'.join(PaymentType)}|nft]", + metavar=metavar_valid_payment_types, + case_sensitive=False, ), payment_chain: Optional[Chain] = typer.Option( - None, help=help_strings.PAYMENT_CHAIN, metavar=f"[{'|'.join([Chain.ETH, Chain.AVAX, Chain.BASE, Chain.SOL])}]" + None, + help=help_strings.PAYMENT_CHAIN, + metavar=metavar_valid_chains, + case_sensitive=False, ), hypervisor: Optional[HypervisorType] = typer.Option(HypervisorType.qemu, help=help_strings.HYPERVISOR), name: Optional[str] = typer.Option(None, help=help_strings.INSTANCE_NAME), @@ -120,7 +128,7 @@ async def create( channel: Optional[str] = typer.Option(default=settings.DEFAULT_CHANNEL, help=help_strings.CHANNEL), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), - print_messages: bool = typer.Option(False), + print_message: bool = typer.Option(False), verbose: bool = typer.Option(True), debug: bool = False, ) -> Tuple[ItemHash, Optional[str], Chain]: @@ -145,6 +153,7 @@ async def create( config = load_main_configuration(settings.CONFIG_FILE) if config is not None: payment_chain = config.chain + console.print(f"Preset to default chain: [green]{payment_chain}[/green]") else: console.print("No active chain selected in configuration.") @@ -158,28 +167,27 @@ async def create( # Force-switches if NFT payment-type if payment_type == "nft": - payment_chain = Chain.AVAX payment_type = PaymentType.hold - console.print( - "[yellow]NFT[/yellow] payment-type selected: Auto-switch to [cyan]AVAX[/cyan] with [red]HOLD[/red]" + payment_chain = Chain( + Prompt.ask( + "On which chain did you claim your NFT voucher?", + choices=[Chain.AVAX.value, Chain.BASE.value, Chain.SOL.value], + default=Chain.AVAX.value, + ) ) elif payment_type in [ptype.value for ptype in PaymentType]: payment_type = PaymentType(payment_type) else: raise ValueError(f"Invalid payment-type: {payment_type}") - is_stream = payment_type != PaymentType.hold - hold_chains = get_chains_with_holding() + [Chain.SOL.value] - super_token_chains = get_chains_with_super_token() - # Checks if payment-chain is compatible with PAYG + is_stream = payment_type != PaymentType.hold if is_stream: - if payment_chain == Chain.SOL: - console.print( - "[yellow]SOL[/yellow] chain selected: [red]Not compatible yet with Pay-As-You-Go.[/red]\nChange your configuration or provide another chain using arguments (but EVM address will be used)." - ) - raise typer.Exit(code=1) - elif payment_chain is None or payment_chain not in super_token_chains: + if payment_chain is None or payment_chain not in super_token_chains: + if payment_chain: + console.print( + f"[red]{safe_getattr(payment_chain, 'value') or payment_chain}[/red] incompatible with Pay-As-You-Go." + ) payment_chain = Chain( Prompt.ask( "Which chain do you want to use for Pay-As-You-Go?", @@ -187,8 +195,12 @@ async def create( default=Chain.AVAX.value, ) ) - # Fallback for Hold-tier if no config / no chain is set - elif payment_chain is None: + # Fallback for Hold-tier if no config / no chain is set / chain not in hold_chains + elif payment_chain is None or payment_chain not in hold_chains: + if payment_chain: + console.print( + f"[red]{safe_getattr(payment_chain, 'value') or payment_chain}[/red] incompatible with Hold-tier." + ) payment_chain = Chain( Prompt.ask( "Which chain do you want to use for Hold-tier?", @@ -270,23 +282,31 @@ async def create( # Validate rootfs message exist async with AlephHttpClient(api_server=settings.API_HOST) as client: - rootfs_message: StoreMessage = await client.get_message(item_hash=rootfs, message_type=StoreMessage) + rootfs_message: Optional[StoreMessage] = None + try: + rootfs_message = await client.get_message(item_hash=rootfs, message_type=StoreMessage) + except MessageNotFoundError: + echo(f"Given rootfs volume {rootfs} does not exist on aleph.im") + except ForgottenMessageError: + echo(f"Given rootfs volume {rootfs} has been deleted on aleph.im") if not rootfs_message: - echo("Given rootfs volume does not exist on aleph.im") raise typer.Exit(code=1) - if rootfs_size is None and rootfs_message.content.size: - rootfs_size = rootfs_message.content.size + elif rootfs_size is None: + rootfs_size = safe_getattr(rootfs_message, "content.size") # Validate confidential firmware message exist confidential_firmware_as_hash = None if confidential: async with AlephHttpClient(api_server=settings.API_HOST) as client: confidential_firmware_as_hash = ItemHash(confidential_firmware) - firmware_message: StoreMessage = await client.get_message( - item_hash=confidential_firmware, message_type=StoreMessage - ) - if not firmware_message: + firmware_message: Optional[StoreMessage] = None + try: + firmware_message = await client.get_message(item_hash=confidential_firmware, message_type=StoreMessage) + except MessageNotFoundError: echo("Confidential Firmware hash does not exist on aleph.im") + except ForgottenMessageError: + echo("Confidential Firmware hash has been deleted on aleph.im") + if not firmware_message: raise typer.Exit(code=1) name = name or validated_prompt("Instance name", lambda x: len(x) < 65) @@ -320,11 +340,28 @@ async def create( crn_name, score, reward_addr = "?", 0, "" nodes: NodeInfo = await _fetch_nodes() for node in nodes.nodes: - if node["address"].rstrip("/") == crn_url: - crn_name = node["name"] - score = node["score"] - reward_addr = node["stream_reward"] - break + found_node, hash_match = None, False + try: + if sanitize_url(node["address"]) == crn_url: + found_node = node + if found_node["hash"] == crn_hash: + hash_match = True + except aiohttp.InvalidURL: + logger.debug(f"Invalid URL for node `{node['hash']}`: {node['address']}") + if found_node: + if hash_match: + crn_name = found_node["name"] + score = found_node["score"] + reward_addr = found_node["stream_reward"] + break + else: + echo( + f"* Provided CRN *\nUrl: {crn_url}\nHash: {crn_hash}\n\n* Found CRN *\nUrl: {found_node['address']}\nHash: {found_node['hash']}\n\nMismatch between provided CRN and found CRN" + ) + raise typer.Exit(1) + if crn_name == "?": + echo(f"* Provided CRN *\nUrl: {crn_url}\nHash: {crn_hash}\n\nCRN not found in aggregate") + raise typer.Exit(1) crn_info = await fetch_crn_info(crn_url) if crn_info: crn = CRNInfo( @@ -343,9 +380,7 @@ async def create( ), gpu_support=bool(crn_info.get("computing", {}).get("ENABLE_GPU_SUPPORT", False)), ) - echo("\n* Selected CRN *") crn.display_crn_specs() - echo() except Exception as e: echo(f"Unable to fetch CRN config: {e}") raise typer.Exit(1) @@ -358,9 +393,8 @@ async def create( if not crn: # User has ctrl-c raise typer.Exit(1) - echo("\n* Selected CRN *") crn.display_crn_specs() - if not Confirm.ask("\nDeploy on this node ?"): + if not yes_no_input("Deploy on this node?", default=True): crn = None continue elif crn_url or crn_hash: @@ -368,51 +402,67 @@ async def create( f"`--crn-url` and/or `--crn-hash` arguments have been ignored.\nHold-tier regular instances are scheduled automatically on available CRNs by the Aleph.im network." ) - gpu_requirement = None + requirements, trusted_execution, gpu_requirement = None, None, None if crn: - stream_reward_address = crn.stream_reward_address if hasattr(crn, "stream_reward_address") else "" + stream_reward_address = safe_getattr(crn, "stream_reward_address") or "" if is_stream and not stream_reward_address: echo("Selected CRN does not have a defined receiver address.") raise typer.Exit(1) - if is_qemu and (not hasattr(crn, "qemu_support") or not crn.qemu_support): + if is_qemu and not safe_getattr(crn, "qemu_support"): echo("Selected CRN does not support QEMU hypervisor.") raise typer.Exit(1) - if confidential and (not hasattr(crn, "confidential_computing") or not crn.confidential_computing): - echo("Selected CRN does not support confidential computing.") - raise typer.Exit(1) - if gpu and (not hasattr(crn, "gpu_support") or not crn.gpu_support): - echo("Selected CRN does not support GPU computing.") - raise typer.Exit(1) + if confidential: + if not safe_getattr(crn, "confidential_computing"): + echo("Selected CRN does not support confidential computing.") + raise typer.Exit(1) + trusted_execution = TrustedExecutionEnvironment(firmware=confidential_firmware_as_hash) if gpu: + if not safe_getattr(crn, "gpu_support"): + echo("Selected CRN does not support GPU computing.") + raise typer.Exit(1) if crn.machine_usage and crn.machine_usage.gpu: if len(crn.machine_usage.gpu.available_devices) < 1: echo("Selected CRN does not have any GPUs available.") raise typer.Exit(1) - echo("Select GPU to use:") - table = Table(box=box.SIMPLE_HEAVY) - table.add_column("Number", style="white", overflow="fold") + table = Table(box=box.ROUNDED) + table.add_column("Id", style="white", overflow="fold") table.add_column("Vendor", style="blue") - table.add_column("Model", style="magenta") + table.add_column("Model GPU", style="magenta") available_gpus = crn.machine_usage.gpu.available_devices for index, available_gpu in enumerate(available_gpus): - table.add_row(str(index), available_gpu.vendor, available_gpu.device_name) + table.add_row(str(index + 1), available_gpu.vendor, available_gpu.device_name) table.add_section() console.print(table) - selected_gpu_number = validated_int_prompt( - "GPU number to use", min_value=0, max_value=len(available_gpus) - 1 + selected_gpu_number = ( + validated_int_prompt("GPU Id to use", min_value=1, max_value=len(available_gpus)) - 1 ) selected_gpu = available_gpus[selected_gpu_number] - console.print(f"Selected GPU from vendor {selected_gpu.vendor} model {selected_gpu.device_name}") + gpu_selection = Text.from_markup( + f"[orange3]Vendor[/orange3]: {selected_gpu.vendor}\n[orange3]Model[/orange3]: {selected_gpu.device_name}" + ) + console.print( + Panel( + gpu_selection, + title="Selected GPU", + border_style="bright_cyan", + expand=False, + title_align="left", + ) + ) gpu_requirement = [ GpuProperties( vendor=selected_gpu.vendor, device_name=selected_gpu.device_name, - device_class=GpuDeviceClass(selected_gpu.device_class), + device_class=selected_gpu.device_class, device_id=selected_gpu.device_id, ) ] + requirements = HostRequirements( + node=NodeRequirements(node_hash=crn.hash), + gpu=gpu_requirement, + ) async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: payment = Payment( @@ -435,17 +485,8 @@ async def create( ssh_keys=[ssh_pubkey], hypervisor=hypervisor, payment=payment, - requirements=( - HostRequirements( - node=NodeRequirements(node_hash=crn.hash), - gpu=gpu_requirement, - ) - if crn - else None - ), - trusted_execution=( - TrustedExecutionEnvironment(firmware=confidential_firmware_as_hash) if confidential else None - ), + requirements=requirements, + trusted_execution=trusted_execution, ) except InsufficientFundsError as e: echo( @@ -453,11 +494,14 @@ async def create( f"{account.get_address()} on {account.CHAIN} has {e.available_funds} ALEPH but needs {e.required_funds} ALEPH." ) raise typer.Exit(code=1) - if print_messages: + except Exception as e: + echo(f"Instance creation failed:\n{e}") + raise typer.Exit(code=1) + if print_message: echo(f"{message.json(indent=4)}") item_hash: ItemHash = message.item_hash - item_hash_text = Text(item_hash, style="bright_cyan") + infos = [] # Instances that need to be started by notifying a specific CRN crn_url = crn.url if crn and crn.url else None @@ -491,8 +535,22 @@ async def create( # Wait for the flow transaction to be confirmed await wait_for_confirmed_flow(account, message.content.payment.receiver) if flow_hash: - echo( - f"Flow {flow_hash} has been created:\n - Aleph cost summary:\n {price.required_tokens:.7f}/sec | {3600*price.required_tokens:.2f}/hour | {86400*price.required_tokens:.2f}/day | {2592000*price.required_tokens:.2f}/month\n - CRN receiver address: {crn.stream_reward_address}" + flow_info = "\n".join( + f"[orange3]{key}[/orange3]: {value}" + for key, value in { + "Hash": flow_hash, + "Aleph cost": f"{price.required_tokens:.7f}/sec | {3600*price.required_tokens:.2f}/hour | {86400*price.required_tokens:.2f}/day | {2592000*price.required_tokens:.2f}/month", + "CRN receiver address": crn.stream_reward_address, + }.items() + ) + console.print( + Panel( + flow_info, + title="Flow Created", + border_style="violet", + expand=False, + title_align="left", + ) ) # Notify CRN @@ -502,54 +560,65 @@ async def create( if int(status) != 200: echo(f"Could not allocate instance {item_hash} on CRN.") return item_hash, crn_url, payment_chain - console.print(f"Your instance {item_hash_text} has been deployed on aleph.im.") + + infos += [ + Text.from_markup(f"Your instance [bright_cyan]{item_hash}[/bright_cyan] has been deployed on aleph.im.") + ] if verbose: # PAYG-tier non-confidential instances if not confidential: - console.print( - "\n\nTo get the IPv6 address of the instance, check out:\n\n", + infos += [ Text.assemble( - " aleph instance list\n", - style="italic", - ), - ) + "\n\nTo get your instance's IPv6, check out:\n", + Text.assemble( + "↳ aleph instance list", + style="italic", + ), + "\n\nTo access your instance's logs, use:\n", + Text.from_markup( + f"↳ aleph instance logs [bright_cyan]{item_hash}[/bright_cyan]", + style="italic", + ), + ) + ] # All confidential instances else: - console.print( - "\n\nInitialize a confidential session using:\n\n", - # Text.assemble( - # " aleph instance confidential-init-session ", - # item_hash_text, - # style="italic", - # ), - # "\n\nThen start it using:\n\n", - # Text.assemble( - # " aleph instance confidential-start ", - # item_hash_text, - # style="italic", - # ), - # "\n\nOr just use the all-in-one command:\n\n", + infos += [ Text.assemble( - " aleph instance confidential ", - item_hash_text, - "\n", - style="italic", - ), - ) + "\n\nInitialize/start your confidential instance with:\n", + Text.from_markup( + f"↳ aleph instance confidential [bright_cyan]{item_hash}[/bright_cyan]", + style="italic", + ), + ) + ] # Instances started automatically by the scheduler (hold-tier non-confidential) else: - console.print( - f"Your instance {item_hash_text} is registered to be deployed on aleph.im.", - "\nThe scheduler usually takes a few minutes to set it up and start it.", - ) + infos += [ + Text.from_markup( + f"Your instance [bright_cyan]{item_hash}[/bright_cyan] is registered to be deployed on aleph.im.\nThe scheduler usually takes a few minutes to set it up and start it." + ) + ] if verbose: - console.print( - "\n\nTo get the IPv6 address of the instance, check out:\n\n", + infos += [ Text.assemble( - " aleph instance list\n", - style="italic", - ), - ) + "\n\nTo get your instance's IPv6, check out:\n", + Text.assemble( + "↳ aleph instance list", + style="italic", + ), + "\n\nTo access your instance's logs, use:\n", + Text.from_markup( + f"↳ aleph instance logs [bright_cyan]{item_hash}[/bright_cyan]", + style="italic", + ), + ) + ] + console.print( + Panel( + Text.assemble(*infos), title="Instance Created", border_style="green", expand=False, title_align="left" + ) + ) return item_hash, crn_url, payment_chain @@ -557,8 +626,8 @@ async def create( async def delete( item_hash: str = typer.Argument(..., help="Instance item hash to forget"), reason: str = typer.Option("User deletion", help="Reason for deleting the instance"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.ADDRESS_CHAIN), - crn_url: Optional[str] = typer.Option(None, help=help_strings.CRN_URL_VM_DELETION), + chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED, metavar=metavar_valid_chains), + domain: Optional[str] = typer.Option(None, help=help_strings.CRN_URL_VM_DELETION), private_key: Optional[str] = settings.PRIVATE_KEY_STRING, private_key_file: Optional[Path] = settings.PRIVATE_KEY_FILE, print_message: bool = typer.Option(False), @@ -598,23 +667,23 @@ async def delete( node_list: NodeInfo = await _fetch_nodes() _, info = await fetch_vm_info(existing_message, node_list) auto_scheduled = info["allocation_type"] == help_strings.ALLOCATION_AUTO - crn_url = str(info["crn_url"]) - if not auto_scheduled and crn_url: - try: - status = await erase( - vm_id=item_hash, - domain=crn_url, - chain=chain, - private_key=private_key, - private_key_file=private_key_file, - silent=True, - debug=debug, - ) - if status == 1: - echo(f"No associated VM on {crn_url}. Skipping...") - except Exception as e: - logger.debug(f"Error while deleting associated VM on {crn_url}: {str(e)}") - echo(f"Failed to erase associated VM on {crn_url}. Skipping...") + crn_url = (info["crn_url"] not in [help_strings.CRN_PENDING, help_strings.CRN_UNKNOWN] and info["crn_url"]) or ( + domain and sanitize_url(domain) + ) + if not auto_scheduled: + if not crn_url: + echo("CRN domain not found or invalid. Skipping...") + else: + try: + async with VmClient(account, crn_url) as manager: + status, _ = await manager.erase_instance(vm_id=item_hash) + if status == 200: + echo(f"VM erased on CRN: {crn_url}") + else: + echo(f"No associated VM on {crn_url}. Skipping...") + except Exception as e: + logger.debug(f"Error while deleting associated VM on {crn_url}: {str(e)}") + echo(f"Failed to erase associated VM on {crn_url}. Skipping...") else: echo(f"Instance {item_hash} was auto-scheduled, VM will be erased automatically.") @@ -636,9 +705,9 @@ async def delete( async def _show_instances(messages: List[InstanceMessage], node_list: NodeInfo): - table = Table(box=box.SIMPLE_HEAVY) + table = Table(box=box.ROUNDED, style="blue_violet") table.add_column(f"Instances [{len(messages)}]", style="blue", overflow="fold") - table.add_column("Specifications", style="magenta") + table.add_column("Specifications", style="blue") table.add_column("Logs", style="blue", overflow="fold") scheduler_responses = dict(await asyncio.gather(*[fetch_vm_info(message, node_list) for message in messages])) @@ -655,45 +724,35 @@ async def _show_instances(messages: List[InstanceMessage], node_list: NodeInfo): and "name" in message.content.metadata else "-" ), - style="orchid", + style="magenta3", ) link = f"https://explorer.aleph.im/address/ETH/{message.sender}/message/INSTANCE/{message.item_hash}" # link = f"{settings.API_HOST}/api/v0/messages/{message.item_hash}" item_hash_link = Text.from_markup(f"[link={link}]{message.item_hash}[/link]", style="bright_cyan") - is_hold = str(info["payment"]).startswith("hold") + is_hold = info["payment"] == "hold" payment = Text.assemble( "Payment: ", Text( - str(info["payment"]).capitalize(), + info["payment"].capitalize().ljust(12), style="red" if is_hold else "orange3", ), ) + confidential = Text.assemble( + "Type: ", Text("Confidential", style="green") if info["confidential"] else Text("Regular", style="grey50") + ) + chain = Text.assemble("Chain: ", Text(info["chain"].ljust(14), style="white")) + created_at = Text.assemble( + "Created at: ", Text(str(str_to_datetime(info["created_at"])).split(".", maxsplit=1)[0], style="orchid") + ) cost: Text | str = "" if not is_hold: async with AlephHttpClient(api_server=settings.API_HOST) as client: price: PriceResponse = await client.get_program_price(message.item_hash) - psec = Text(f"{price.required_tokens:.7f}/sec", style="bright_magenta") - phour = Text(f"{3600*price.required_tokens:.2f}/hour", style="bright_magenta") - pday = Text(f"{86400*price.required_tokens:.2f}/day", style="bright_magenta") - pmonth = Text(f"{2592000*price.required_tokens:.2f}/month", style="bright_magenta") - cost = Text.assemble("Aleph cost: ", psec, " | ", phour, " | ", pday, " | ", pmonth, "\n") - confidential = ( - Text.assemble("Type: ", Text("Confidential", style="green")) - if info["confidential"] - else Text.assemble("Type: ", Text("Regular", style="grey50")) - ) - chain_label, chain_color = str(info["chain"]), "steel_blue" - if chain_label == "AVAX": - chain_label, chain_color = "AVAX", "bright_red" - elif chain_label == "BASE": - chain_label, chain_color = "BASE", "blue3" - elif chain_label == "SOL": - chain_label, chain_color = "SOL ", "medium_spring_green" - else: # ETH - chain_label += " " - chain = Text.assemble("Chain: ", Text(chain_label, style=chain_color)) - created_at_parsed = str(str_to_datetime(str(info["created_at"]))).split(".")[0] - created_at = Text.assemble("\t Created at: ", Text(created_at_parsed, style="magenta")) + psec = Text(f"{price.required_tokens:.7f}/sec", style="magenta3") + phour = Text(f"{3600*price.required_tokens:.2f}/hour", style="magenta3") + pday = Text(f"{86400*price.required_tokens:.2f}/day", style="magenta3") + pmonth = Text(f"{2592000*price.required_tokens:.2f}/month", style="magenta3") + cost = Text.assemble("\nAleph cost: ", psec, " | ", phour, " | ", pday, " | ", pmonth) instance = Text.assemble( "Item Hash ↓\t Name: ", name, @@ -701,89 +760,91 @@ async def _show_instances(messages: List[InstanceMessage], node_list: NodeInfo): item_hash_link, "\n", payment, - " ", confidential, "\n", - cost, chain, created_at, + cost, ) - specifications = ( - f"vCPUs: {message.content.resources.vcpus}\n" - f"RAM: {message.content.resources.memory / 1_024:.2f} GiB\n" - f"Disk: {message.content.rootfs.size_mib / 1_024:.2f} GiB\n" - f"HyperV: {safe_getattr(message, 'content.environment.hypervisor.value').capitalize() if safe_getattr(message, 'content.environment.hypervisor') else 'Firecracker'}\n" - ) + hypervisor = safe_getattr(message, "content.environment.hypervisor") + specs = [ + f"vCPU: [magenta3]{message.content.resources.vcpus}[/magenta3]\n", + f"RAM: [magenta3]{message.content.resources.memory / 1_024:.2f} GiB[/magenta3]\n", + f"Disk: [magenta3]{message.content.rootfs.size_mib / 1_024:.2f} GiB[/magenta3]\n", + f"HyperV: [magenta3]{hypervisor.capitalize() if hypervisor else 'Firecracker'}[/magenta3]", + ] + gpus = safe_getattr(message, "content.requirements.gpu") + if gpus: + specs += [f"\n[bright_yellow]GPU [[green]{len(gpus)}[/green]]:\n"] + for gpu in gpus: + specs += [f"• [green]{gpu.vendor}, {gpu.device_name}[green]"] + specs += ["[/bright_yellow]"] + specifications = Text.from_markup("".join(specs)) status_column = Text.assemble( Text.assemble( Text("Allocation: ", style="blue"), Text( - str(info["allocation_type"]) + "\n", + info["allocation_type"] + "\n", style="magenta3" if info["allocation_type"] == help_strings.ALLOCATION_MANUAL else "deep_sky_blue1", ), ), + ( + Text.assemble( + Text("CRN Hash: ", style="blue"), + Text(info["crn_hash"] + "\n", style=("bright_cyan")), + ) + if info["crn_hash"] + else "" + ), Text.assemble( - Text("Target CRN: ", style="blue"), + Text("CRN Url: ", style="blue"), Text( - str(info["crn_url"]) + "\n", - style="green1" if str(info["crn_url"]).startswith("http") else "dark_slate_gray1", + info["crn_url"] + "\n", + style="green1" if info["crn_url"].startswith("http") else "grey50", ), ), Text.assemble( Text("IPv6: ", style="blue"), - Text(str(info["ipv6_logs"])), - style="bright_yellow" if len(str(info["ipv6_logs"]).split(":")) == 8 else "dark_orange", + Text(info["ipv6_logs"]), + style="bright_yellow" if len(info["ipv6_logs"].split(":")) == 8 else "dark_orange", ), ) table.add_row(instance, specifications, status_column) table.add_section() + console = Console() - console.print( - f"\n[bold]Address:[/bold] {messages[0].content.address}", - ) console.print(table) + + infos = [Text.from_markup(f"[bold]Address:[/bold] [bright_cyan]{messages[0].content.address}[/bright_cyan]")] if uninitialized_confidential_found: - item_hash_field = Text("", style="bright_cyan") - console.print( - "To start uninitialized confidential instance(s), use:\n\n", - # Text.assemble( - # " aleph instance confidential-init-session ", - # item_hash_field, - # "\n", - # style="italic", - # ), - # Text.assemble( - # " aleph instance confidential-start ", - # item_hash_field, - # style="italic", - # ), - # "\n\nOr just use the all-in-one command:\n\n", + infos += [ Text.assemble( - " aleph instance confidential ", - item_hash_field, - "\n", + "\n\nBoot uninitialized/unstarted confidential instances with:\n", + Text.from_markup( + "↳ aleph instance confidential [bright_cyan][/bright_cyan]", style="italic" + ), + ) + ] + infos += [ + Text.assemble( + "\n\nConnect to an instance with:\n", + Text.from_markup( + "↳ ssh root@[yellow][/yellow] [-i [orange3][/orange3]]", style="italic", ), ) + ] console.print( - "To connect to an instance, use:\n\n", - Text.assemble( - " ssh root@", - Text("", style="yellow"), - " -i ", - Text("", style="orange3"), - "\n", - style="italic", - ), + Panel(Text.assemble(*infos), title="Infos", border_style="bright_cyan", expand=False, title_align="left") ) -@app.command() -async def list( - address: Optional[str] = typer.Option(None, help="Owner address of the instance"), +@app.command(name="list") +async def list_instances( + address: Optional[str] = typer.Option(None, help="Owner address of the instances"), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), - chain: Optional[Chain] = typer.Option(None, help=help_strings.ADDRESS_CHAIN), + chain: Optional[Chain] = typer.Option(None, help=help_strings.ADDRESS_CHAIN, metavar=metavar_valid_chains), json: bool = typer.Option(default=False, help="Print as json instead of rich table"), debug: bool = False, ): @@ -808,7 +869,8 @@ async def list( echo(f"Address: {address}\n\nNo instance found\n") raise typer.Exit(code=1) if json: - echo(messages.json(indent=4)) + for message in messages: + echo(message.json(indent=4)) else: # Since we filtered on message type, we can safely cast as InstanceMessage. messages = cast(List[InstanceMessage], messages) @@ -816,71 +878,11 @@ async def list( await _show_instances(messages, resource_nodes) -@app.command() -async def expire( - vm_id: str = typer.Argument(..., help="VM item hash to expire"), - domain: Optional[str] = typer.Option(None, help="CRN domain on which the VM is running"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED), - private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), - private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), - debug: bool = False, -): - """Expire an instance""" - - setup_logging(debug) - - domain = ( - (domain and sanitize_url(domain)) - or await find_crn_of_vm(vm_id) - or Prompt.ask("URL of the CRN (Compute node) on which the VM is running") - ) - - account = _load_account(private_key, private_key_file, chain=chain) - - async with VmClient(account, domain) as manager: - status, result = await manager.expire_instance(vm_id=vm_id) - if status != 200: - echo(f"Status: {status}") - return 1 - echo(f"VM expired on CRN: {domain}") - - -@app.command() -async def erase( - vm_id: str = typer.Argument(..., help="VM item hash to erase"), - domain: Optional[str] = typer.Option(None, help="CRN domain on which the VM is stored or running"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED), - private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), - private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), - silent: bool = False, - debug: bool = False, -): - """Erase an instance stored or running on a CRN""" - - setup_logging(debug) - - domain = ( - (domain and sanitize_url(domain)) - or await find_crn_of_vm(vm_id) - or Prompt.ask("URL of the CRN (Compute node) on which the VM is stored or running") - ) - - account = _load_account(private_key, private_key_file, chain=chain) - - async with VmClient(account, domain) as manager: - status, result = await manager.erase_instance(vm_id=vm_id) - if status != 200: - if not silent: - echo(f"Status: {status}") - return 1 - echo(f"VM erased on CRN: {domain}") - - @app.command() async def reboot( vm_id: str = typer.Argument(..., help="VM item hash to reboot"), domain: Optional[str] = typer.Option(None, help="CRN domain on which the VM is running"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED), + chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED, metavar=metavar_valid_chains), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), debug: bool = False, @@ -909,7 +911,7 @@ async def reboot( async def allocate( vm_id: str = typer.Argument(..., help="VM item hash to allocate"), domain: Optional[str] = typer.Option(None, help="CRN domain on which the VM will be allocated"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED), + chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED, metavar=metavar_valid_chains), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), debug: bool = False, @@ -938,7 +940,7 @@ async def allocate( async def logs( vm_id: str = typer.Argument(..., help="VM item hash to retrieve the logs from"), domain: Optional[str] = typer.Option(None, help="CRN domain on which the VM is running"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED), + chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED, metavar=metavar_valid_chains), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), debug: bool = False, @@ -946,11 +948,7 @@ async def logs( """Retrieve the logs of an instance""" setup_logging(debug) - domain = ( - (domain and sanitize_url(domain)) - or await find_crn_of_vm(vm_id) - or Prompt.ask("URL of the CRN (Compute node) on which the instance is running") - ) + domain = (domain and sanitize_url(domain)) or await find_crn_of_vm(vm_id) or Prompt.ask(help_strings.PROMPT_CRN_URL) account = _load_account(private_key, private_key_file, chain=chain) @@ -979,11 +977,7 @@ async def stop( setup_logging(debug) - domain = ( - (domain and sanitize_url(domain)) - or await find_crn_of_vm(vm_id) - or Prompt.ask("URL of the CRN (Compute node) on which the instance is running") - ) + domain = (domain and sanitize_url(domain)) or await find_crn_of_vm(vm_id) or Prompt.ask(help_strings.PROMPT_CRN_URL) account = _load_account(private_key, private_key_file, chain=chain) @@ -999,21 +993,21 @@ async def stop( async def confidential_init_session( vm_id: str = typer.Argument(..., help="VM item hash to initialize the session for"), domain: Optional[str] = typer.Option(None, help="CRN domain on which the session will be initialized"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED), + chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED, metavar=metavar_valid_chains), policy: int = typer.Option(default=0x1), keep_session: bool = typer.Option(None, help=help_strings.KEEP_SESSION), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), debug: bool = False, ): - "Initialize a confidential communication session with the VM" - assert settings.CONFIG_HOME + """Initialize a confidential communication session with the VM""" + setup_logging(debug) + + assert settings.CONFIG_HOME session_dir = Path(settings.CONFIG_HOME) / "confidential_sessions" / vm_id session_dir.mkdir(exist_ok=True, parents=True) - setup_logging(debug) - domain = ( (domain and sanitize_url(domain)) or await find_crn_of_vm(vm_id) @@ -1028,8 +1022,9 @@ async def confidential_init_session( godh_path = session_dir / "vm_godh.b64" if godh_path.exists() and keep_session is None: - keep_session = not Confirm.ask( - "Session already initiated for this instance, are you sure you want to override the previous one? You won't be able to communicate with already running VM" + keep_session = not yes_no_input( + "Session already initiated for this instance, are you sure you want to override the previous one? You won't be able to communicate with already running VM", + default=True, ) if keep_session: echo("Keeping already initiated session") @@ -1057,41 +1052,40 @@ async def confidential_init_session( godh_path = session_dir / "vm_godh.b64" session_path = session_dir / "vm_session.b64" assert godh_path.exists() - await client.initialize(vm_hash, session_path, godh_path) - echo("Confidential Session with VM and CRN initiated") + try: + await client.initialize(vm_hash, session_path, godh_path) + echo("Confidential Session with VM and CRN initiated") + except Exception as e: + await client.close() + echo(f"Failed to initiate confidential session with VM and CRN, reason:\n{e}") + return 1 await client.close() -def find_sevctl_or_exit() -> Path: - "Find sevctl in path, exit with message if not available" - sevctl_path = shutil.which("sevctl") - if sevctl_path is None: - echo("sevctl binary is not available. Please install sevctl, ensure it is in the PATH and try again.") - echo("Instructions for setup https://docs.aleph.im/computing/confidential/requirements/") - raise typer.Exit(code=1) - return Path(sevctl_path) - - @app.command() async def confidential_start( vm_id: str = typer.Argument(..., help="VM item hash to start"), domain: Optional[str] = typer.Option(None, help="CRN domain on which the VM will be started"), - chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED), + chain: Optional[Chain] = typer.Option(None, help=help_strings.PAYMENT_CHAIN_USED, metavar=metavar_valid_chains), firmware_hash: str = typer.Option( settings.DEFAULT_CONFIDENTIAL_FIRMWARE_HASH, help=help_strings.CONFIDENTIAL_FIRMWARE_HASH ), - firmware_file: str = typer.Option(None, help=help_strings.PRIVATE_KEY), + firmware_file: str = typer.Option(None, help=help_strings.CONFIDENTIAL_FIRMWARE_PATH), vm_secret: str = typer.Option(None, help=help_strings.VM_SECRET), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), + verbose: bool = typer.Option(True), debug: bool = False, ): - "Validate the authenticity of the VM and start it" + """Validate the authenticity of the VM and start it""" + + setup_logging(debug) + assert settings.CONFIG_HOME session_dir = Path(settings.CONFIG_HOME) / "confidential_sessions" / vm_id session_dir.mkdir(exist_ok=True, parents=True) - setup_logging(debug) + vm_hash = ItemHash(vm_id) account = _load_account(private_key, private_key_file, chain=chain) sevctl_path = find_sevctl_or_exit() @@ -1103,16 +1097,17 @@ async def confidential_start( client = VmConfidentialClient(account, sevctl_path, domain) - bytes.fromhex(firmware_hash) - - vm_hash = ItemHash(vm_id) - if not session_dir.exists(): echo("Please run confidential-init-session first ") return 1 - sev_data = await client.measurement(vm_hash) - echo("Retrieved measurement") + try: + sev_data = await client.measurement(vm_hash) + echo("Retrieved measurement") + except Exception as e: + await client.close() + echo(f"Failed to start the VM, reason:\n{e}") + return 1 tek_path = session_dir / "vm_tek.bin" tik_path = session_dir / "vm_tik.bin" @@ -1133,21 +1128,33 @@ async def confidential_start( secret_key = vm_secret or Prompt.ask("Please enter secret to start the VM", password=True) encoded_packet_header, encoded_secret = await client.build_secret(tek_path, tik_path, sev_data, secret_key) - await client.inject_secret(vm_hash, encoded_packet_header, encoded_secret) + try: + await client.inject_secret(vm_hash, encoded_packet_header, encoded_secret) + except Exception as e: + await client.close() + echo(f"Failed to start the VM, reason:\n{e}") + return 1 await client.close() + console = Console() + infos = [Text.from_markup(f"Your instance [bright_cyan]{vm_id}[/bright_cyan] is currently starting.")] + if verbose: + infos += [ + Text.assemble( + "\n\nTo get your instance's IPv6, check out:\n", + Text.assemble( + "↳ aleph instance list", + style="italic", + ), + "\n\nTo access your instance's logs, use:\n", + Text.from_markup( + f"↳ aleph instance logs [bright_cyan]{vm_id}[/bright_cyan]", + style="italic", + ), + ) + ] console.print( - "Your instance is currently starting...\n\nLogs can be fetched using:\n\n", - Text.assemble( - " aleph instance logs ", - Text(vm_id, style="bright_cyan"), - style="italic", - ), - "\n\nTo get the IPv6 address of the instance, check out:\n\n", - Text.assemble( - " aleph instance list\n", - style="italic", - ), + Panel(Text.assemble(*infos), title="Instance Started", border_style="green", expand=False, title_align="left") ) @@ -1163,18 +1170,21 @@ async def confidential_create( firmware_hash: str = typer.Option( settings.DEFAULT_CONFIDENTIAL_FIRMWARE_HASH, help=help_strings.CONFIDENTIAL_FIRMWARE_HASH ), - firmware_file: Optional[str] = typer.Option(None, help=help_strings.PRIVATE_KEY), + firmware_file: Optional[str] = typer.Option(None, help=help_strings.CONFIDENTIAL_FIRMWARE_PATH), keep_session: Optional[bool] = typer.Option(None, help=help_strings.KEEP_SESSION), vm_secret: Optional[str] = typer.Option(None, help=help_strings.VM_SECRET), payment_type: Optional[str] = typer.Option( None, help=help_strings.PAYMENT_TYPE, callback=lambda pt: None if pt is None else pt.lower(), - # callback=lambda pt: None if pt is None else PaymentType.hold if pt == "nft" else PaymentType(pt), - metavar=f"[{'|'.join(PaymentType)}|nft]", + metavar=metavar_valid_payment_types, + case_sensitive=False, ), payment_chain: Optional[Chain] = typer.Option( - None, help=help_strings.PAYMENT_CHAIN, metavar=f"[{'|'.join([Chain.ETH, Chain.AVAX, Chain.BASE, Chain.SOL])}]" + None, + help=help_strings.PAYMENT_CHAIN, + metavar=metavar_valid_chains, + case_sensitive=False, ), name: Optional[str] = typer.Option(None, help=help_strings.INSTANCE_NAME), rootfs: Optional[str] = typer.Option(None, help=help_strings.ROOTFS), @@ -1213,6 +1223,7 @@ async def confidential_create( # Ensure sevctl is accessible before we start process with user find_sevctl_or_exit() + allocated = False if not vm_id or len(vm_id) != 64: vm_id, crn_url, payment_chain = await create( @@ -1238,7 +1249,7 @@ async def confidential_create( channel=channel, private_key=private_key, private_key_file=private_key_file, - print_messages=False, + print_message=False, verbose=False, debug=debug, ) @@ -1261,9 +1272,7 @@ async def confidential_create( raise typer.Exit(code=1) crn_url = ( - (crn_url and sanitize_url(crn_url)) - or await find_crn_of_vm(vm_id) - or Prompt.ask("URL of the CRN (Compute node) on which the instance is running") + (crn_url and sanitize_url(crn_url)) or await find_crn_of_vm(vm_id) or Prompt.ask(help_strings.PROMPT_CRN_URL) ) if not allocated: @@ -1292,6 +1301,10 @@ async def confidential_create( echo("Could not initialize the session") return 1 + # Safe delay to ensure instance is starting and is ready + echo("Waiting 10sec before to start...") + await asyncio.sleep(10) + await confidential_start( vm_id=vm_id, domain=crn_url, @@ -1301,5 +1314,6 @@ async def confidential_create( vm_secret=vm_secret, private_key=private_key, private_key_file=private_key_file, + verbose=True, debug=debug, ) diff --git a/src/aleph_client/commands/instance/network.py b/src/aleph_client/commands/instance/network.py index 6394e7ef..0e92c282 100644 --- a/src/aleph_client/commands/instance/network.py +++ b/src/aleph_client/commands/instance/network.py @@ -8,16 +8,19 @@ import aiohttp from aleph.sdk import AlephHttpClient from aleph.sdk.conf import settings +from aleph.sdk.exceptions import ForgottenMessageError, MessageNotFoundError +from aleph.sdk.utils import safe_getattr from aleph_message.models import InstanceMessage from aleph_message.models.execution.base import PaymentType from aleph_message.models.item_hash import ItemHash +from click import echo from pydantic import ValidationError +from typer import Exit from aleph_client.commands import help_strings from aleph_client.commands.node import NodeInfo, _fetch_nodes -from aleph_client.commands.utils import safe_getattr from aleph_client.models import MachineUsage -from aleph_client.utils import AsyncTyper, fetch_json, sanitize_url +from aleph_client.utils import fetch_json, sanitize_url logger = logging.getLogger(__name__) @@ -68,7 +71,7 @@ async def fetch_crn_info(node_url: str) -> dict | None: return None -async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[str, dict[str, object]]: +async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[str, dict[str, str]]: """ Fetches VM information given an instance message and the node list. @@ -79,24 +82,29 @@ async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[ VM information. """ async with aiohttp.ClientSession() as session: - hold = not message.content.payment or message.content.payment.type == PaymentType["hold"] + chain = safe_getattr(message, "content.payment.chain.value") + hold = safe_getattr(message, "content.payment.type.value") crn_hash = safe_getattr(message, "content.requirements.node.node_hash") created_at = safe_getattr(message, "content.time") + + is_hold = hold == PaymentType.hold.value firmware = safe_getattr(message, "content.environment.trusted_execution.firmware") - confidential = firmware and len(firmware) == 64 + is_confidential = firmware and len(firmware) == 64 + has_gpu = safe_getattr(message, "content.requirements.gpu") + info = dict( crn_hash=str(crn_hash) if crn_hash else "", created_at=str(created_at), - payment="hold\t " if hold else str(safe_getattr(message, "content.payment.type.value")), - chain=str(safe_getattr(message, "content.payment.chain.value")), - confidential=confidential, + payment=str(hold), + chain=str(chain), + confidential=str(firmware) if is_confidential else "", allocation_type="", ipv6_logs="", crn_url="", ) try: # Fetch from the scheduler API directly if no payment or no receiver (hold-tier non-confidential) - if hold and not confidential: + if is_hold and not is_confidential and not has_gpu: try: url = f"https://scheduler.api.aleph.cloud/api/v0/allocation/{message.item_hash}" info["allocation_type"] = help_strings.ALLOCATION_AUTO @@ -107,24 +115,28 @@ async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[ for node in nodes["nodes"]: if node["ipv6"].split("::")[0] == ":".join(str(info["ipv6_logs"]).split(":")[:4]): info["crn_url"] = node["url"].rstrip("/") - return message.item_hash, info + break except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError) as e: - info["ipv6_logs"] = help_strings.VM_SCHEDULED info["crn_url"] = help_strings.CRN_PENDING + info["ipv6_logs"] = help_strings.VM_SCHEDULED logger.debug(f"Error while calling Scheduler API ({url}): {e}") else: - # Fetch from the CRN API if PAYG-tier or confidential + # Fetch from the CRN API if PAYG-tier or confidential or GPU info["allocation_type"] = help_strings.ALLOCATION_MANUAL for node in node_list.nodes: - if node["hash"] == safe_getattr(message, "content.requirements.node.node_hash"): + if node["hash"] == crn_hash: info["crn_url"] = node["address"].rstrip("/") - path = f"{node['address'].rstrip('/')}/about/executions/list" - executions = await fetch_json(session, path) - if message.item_hash in executions: - interface = IPv6Interface(executions[message.item_hash]["networking"]["ipv6"]) - info["ipv6_logs"] = str(interface.ip + 1) - return message.item_hash, info - info["ipv6_logs"] = help_strings.VM_NOT_READY if confidential else help_strings.VM_NOT_AVAILABLE_YET + break + if info["crn_url"]: + path = f"{info['crn_url']}/about/executions/list" + executions = await fetch_json(session, path) + if message.item_hash in executions: + interface = IPv6Interface(executions[message.item_hash]["networking"]["ipv6"]) + info["ipv6_logs"] = str(interface.ip + 1) + else: + info["crn_url"] = help_strings.CRN_UNKNOWN + if not info["ipv6_logs"]: + info["ipv6_logs"] = help_strings.VM_NOT_READY except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError) as e: info["ipv6_logs"] = f"Not available. Server error: {e}" return message.item_hash, info @@ -132,8 +144,16 @@ async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[ async def find_crn_of_vm(vm_id: str) -> Optional[str]: async with AlephHttpClient(api_server=settings.API_HOST) as client: - message: InstanceMessage = await client.get_message(item_hash=ItemHash(vm_id), message_type=InstanceMessage) + message: Optional[InstanceMessage] = None + try: + message = await client.get_message(item_hash=ItemHash(vm_id), message_type=InstanceMessage) + except MessageNotFoundError: + echo("Instance does not exist on aleph.im") + except ForgottenMessageError: + echo("Instance has been deleted on aleph.im") + if not message: + raise Exit(code=1) node_list: NodeInfo = await _fetch_nodes() _, info = await fetch_vm_info(message, node_list) - is_valid = info["crn_url"] and info["crn_url"] != help_strings.CRN_PENDING + is_valid = info["crn_url"] not in [help_strings.CRN_PENDING, help_strings.CRN_UNKNOWN] return str(info["crn_url"]) if is_valid else None diff --git a/src/aleph_client/commands/message.py b/src/aleph_client/commands/message.py index 00ee2bbf..981cc11f 100644 --- a/src/aleph_client/commands/message.py +++ b/src/aleph_client/commands/message.py @@ -14,6 +14,7 @@ from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient from aleph.sdk.account import _load_account from aleph.sdk.conf import settings +from aleph.sdk.exceptions import ForgottenMessageError, MessageNotFoundError from aleph.sdk.query.filters import MessageFilter from aleph.sdk.query.responses import MessagesResponse from aleph.sdk.types import AccountFromPrivateKey, StorageEnum @@ -42,13 +43,20 @@ async def get( item_hash: str = typer.Argument(..., help="Item hash of the message"), ): async with AlephHttpClient(api_server=settings.API_HOST) as client: - message, status = await client.get_message(item_hash=ItemHash(item_hash), with_status=True) - typer.echo(f"Message Status: {colorized_status(status)}") - if status == MessageStatus.REJECTED: - reason = await client.get_message_error(item_hash=ItemHash(item_hash)) - typer.echo(colorful_json(json.dumps(reason, indent=4))) - else: - typer.echo(colorful_message_json(message)) + message: Optional[AlephMessage] = None + try: + message, status = await client.get_message(item_hash=ItemHash(item_hash), with_status=True) + except MessageNotFoundError: + typer.echo("Message does not exist on aleph.im") + except ForgottenMessageError: + typer.echo("Message has been forgotten on aleph.im") + if message: + typer.echo(f"Message Status: {colorized_status(status)}") + if status == MessageStatus.REJECTED: + reason = await client.get_message_error(item_hash=ItemHash(item_hash)) + typer.echo(colorful_json(json.dumps(reason, indent=4))) + else: + typer.echo(colorful_message_json(message)) @app.command() @@ -173,41 +181,47 @@ async def amend( account: AccountFromPrivateKey = _load_account(private_key, private_key_file) async with AlephHttpClient(api_server=settings.API_HOST) as client: - existing_message: AlephMessage = await client.get_message(item_hash=item_hash) - - editor: str = os.getenv("EDITOR", default="nano") - with tempfile.NamedTemporaryFile(suffix="json") as fd: - # Fill in message template - fd.write(existing_message.content.json(indent=4).encode()) - fd.seek(0) - - # Launch editor - subprocess.run([editor, fd.name], check=True) - - # Read new message - fd.seek(0) - new_content_json = fd.read() - - content_type = type(existing_message).__annotations__["content"] - new_content_dict = json.loads(new_content_json) - new_content = content_type(**new_content_dict) - - if isinstance(existing_message, ProgramMessage): - new_content.replaces = existing_message.item_hash - else: - new_content.ref = existing_message.item_hash - - new_content.time = time.time() - new_content.type = "amend" - - typer.echo(new_content) - async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: - message, status, response = await client.submit( - content=new_content.dict(), - message_type=existing_message.type, - channel=existing_message.channel, - ) - typer.echo(f"{message.json(indent=4)}") + existing_message: Optional[AlephMessage] = None + try: + existing_message = await client.get_message(item_hash=item_hash) + except MessageNotFoundError: + typer.echo("Message does not exist on aleph.im") + except ForgottenMessageError: + typer.echo("Message has been forgotten on aleph.im") + if existing_message: + editor: str = os.getenv("EDITOR", default="nano") + with tempfile.NamedTemporaryFile(suffix="json") as fd: + # Fill in message template + fd.write(existing_message.content.json(indent=4).encode()) + fd.seek(0) + + # Launch editor + subprocess.run([editor, fd.name], check=True) + + # Read new message + fd.seek(0) + new_content_json = fd.read() + + content_type = type(existing_message).__annotations__["content"] + new_content_dict = json.loads(new_content_json) + new_content = content_type(**new_content_dict) + + if isinstance(existing_message, ProgramMessage): + new_content.replaces = existing_message.item_hash + else: + new_content.ref = existing_message.item_hash + + new_content.time = time.time() + new_content.type = "amend" + + typer.echo(new_content) + async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: + message, status, response = await client.submit( + content=new_content.dict(), + message_type=existing_message.type, + channel=existing_message.channel, + ) + typer.echo(f"{message.json(indent=4)}") @app.command() @@ -241,11 +255,18 @@ async def watch( setup_logging(debug) async with AlephHttpClient(api_server=settings.API_HOST) as client: - original: AlephMessage = await client.get_message(item_hash=ref) - async for message in client.watch_messages( - message_filter=MessageFilter(refs=[ref], addresses=[original.content.address]) - ): - typer.echo(f"{message.json(indent=indent)}") + original: Optional[AlephMessage] = None + try: + original = await client.get_message(item_hash=ref) + except MessageNotFoundError: + typer.echo("Message does not exist on aleph.im") + except ForgottenMessageError: + typer.echo("Message has been forgotten on aleph.im") + if original: + async for message in client.watch_messages( + message_filter=MessageFilter(refs=[ref], addresses=[original.content.address]) + ): + typer.echo(f"{message.json(indent=indent)}") @app.command() diff --git a/src/aleph_client/commands/node.py b/src/aleph_client/commands/node.py index c421e1a7..7f0c3a9b 100644 --- a/src/aleph_client/commands/node.py +++ b/src/aleph_client/commands/node.py @@ -9,18 +9,19 @@ import aiohttp import typer +from aleph.sdk.conf import settings from rich import text from rich.console import Console from rich.markup import escape from rich.table import Table from aleph_client.commands.utils import setup_logging -from aleph_client.utils import AsyncTyper +from aleph_client.utils import AsyncTyper, sanitize_url logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) -node_link = "https://api2.aleph.im/api/v0/aggregates/0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10.json?keys=corechannel" +node_link = f"{sanitize_url(settings.API_HOST)}/api/v0/aggregates/0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10.json?keys=corechannel" class NodeInfo: @@ -78,6 +79,7 @@ def _show_compute(node_info): table.add_column("Decentralization", style="green", justify="right") table.add_column("Status", style="green", justify="right") table.add_column("Item Hash", style="green", justify="center") + table.add_column("URL", style="orchid", justify="center") for node in node_info.nodes: # Prevent escaping with name @@ -91,6 +93,7 @@ def _show_compute(node_info): score = _format_score(node["score"]) decentralization = _format_score(node["decentralization"]) status = _format_status(node["status"]) + node_url = node["address"] table.add_row( score, node_name, @@ -98,21 +101,42 @@ def _show_compute(node_info): decentralization, status, node_hash, + node_url, ) console = Console() console.print(table) -def _filter_node(active: bool, address: Optional[str], core_info): +def _filter_node( + active: bool, + address: Optional[str], + core_info, + payg_receiver=Optional[str], + crn_url=Optional[str], + crn_hash=Optional[str], + ccn_hash=Optional[str], +): result = [] + try: + node_url = not crn_url or sanitize_url(crn_url) + except Exception as e: + logger.debug(e) for node in core_info: - if active and node["status"] == "active" and node["score"] > 0: - result.append(node) - elif address and node["owner"] == address: - result.append(node) - elif not active and not address: - result.append(node) + try: + sanitized_url = node["address"] or sanitize_url(node["address"]) + if ( + (not active or (node["status"] == "linked" and node["score"] > 0)) + and (not address or node["owner"] == address) + and (not payg_receiver or node["stream_reward"] == payg_receiver) + and (not crn_url or (sanitized_url == node_url)) + and (not crn_hash or node["hash"] == crn_hash) + and (not ccn_hash or node["parent"] == ccn_hash) + ): + node["address"] = sanitized_url + result.append(node) + except Exception as e: + logger.debug(e) return result @@ -154,14 +178,28 @@ async def compute( json: bool = typer.Option(default=False, help="Print as json instead of rich table"), active: bool = typer.Option(default=False, help="Only show active nodes"), address: Optional[str] = typer.Option(default=None, help="Owner address to filter by"), + payg_receiver: Optional[str] = typer.Option( + default=None, help="PAYG (Pay-As-You-Go) receiver address to filter by" + ), + crn_url: Optional[str] = typer.Option(default=None, help="CRN Url to filter by"), + crn_hash: Optional[str] = typer.Option(default=None, help="CRN hash to filter by"), + ccn_hash: Optional[str] = typer.Option(default=None, help="Linked CCN hash to filter by"), debug: bool = False, ): - """Get all compute node on aleph network""" + """Get all compute node (CRN) on aleph network""" setup_logging(debug) compute_info: NodeInfo = await _fetch_nodes() - compute_info.nodes = _filter_node(core_info=compute_info.nodes, active=active, address=address) + compute_info.nodes = _filter_node( + core_info=compute_info.nodes, + active=active, + address=address, + payg_receiver=payg_receiver, + crn_url=crn_url, + crn_hash=crn_hash, + ccn_hash=ccn_hash, + ) if not json: _show_compute(compute_info) @@ -176,7 +214,7 @@ async def core( address: Optional[str] = typer.Option(default=None, help="Owner address to filter by"), debug: bool = False, ): - """Get all core node on aleph""" + """Get all core node (CCN) on aleph""" setup_logging(debug) core_info: NodeInfo = await _fetch_nodes() diff --git a/src/aleph_client/commands/program.py b/src/aleph_client/commands/program.py index d5034c0b..765775dd 100644 --- a/src/aleph_client/commands/program.py +++ b/src/aleph_client/commands/program.py @@ -2,32 +2,44 @@ import json import logging -import sys +import re from base64 import b16decode, b32encode from collections.abc import Mapping from pathlib import Path -from typing import Any, List, Optional, Tuple +from typing import List, Optional, cast from zipfile import BadZipFile +import aiohttp import typer -from aiohttp import ClientResponse -from aiohttp.client import _RequestContextManager -from aleph.sdk import AuthenticatedAlephHttpClient +from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient from aleph.sdk.account import _load_account from aleph.sdk.client.vm_client import VmClient from aleph.sdk.conf import settings +from aleph.sdk.exceptions import ForgottenMessageError, MessageNotFoundError +from aleph.sdk.query.filters import MessageFilter from aleph.sdk.types import AccountFromPrivateKey, StorageEnum -from aleph_message.models import Chain, ProgramMessage, StoreMessage +from aleph.sdk.utils import safe_getattr +from aleph_message.models import Chain, MessageType, ProgramMessage, StoreMessage from aleph_message.models.execution.program import ProgramContent from aleph_message.models.item_hash import ItemHash from aleph_message.status import MessageStatus from click import echo +from rich import box +from rich.console import Console +from rich.panel import Panel +from rich.prompt import Prompt +from rich.table import Table +from rich.text import Text from aleph_client.commands import help_strings from aleph_client.commands.utils import ( + filter_only_valid_messages, + get_or_prompt_environment_variables, get_or_prompt_volumes, input_multiline, setup_logging, + str_to_datetime, + validated_prompt, yes_no_input, ) from aleph_client.utils import AsyncTyper, create_archive, sanitize_url @@ -36,40 +48,52 @@ app = AsyncTyper(no_args_is_help=True) -@app.command() +@app.command(name="upload") +@app.command(name="create") async def upload( - path: Path = typer.Argument(..., help="Path to your source code"), - entrypoint: str = typer.Argument(..., help="Your program entrypoint"), + path: Path = typer.Argument(..., help=help_strings.PROGRAM_PATH), + entrypoint: str = typer.Argument( + ..., + help=help_strings.PROGRAM_ENTRYPOINT, + ), channel: Optional[str] = typer.Option(default=settings.DEFAULT_CHANNEL, help=help_strings.CHANNEL), - memory: int = typer.Option(settings.DEFAULT_VM_MEMORY, help="Maximum memory allocation on vm in MiB"), - vcpus: int = typer.Option(settings.DEFAULT_VM_VCPUS, help="Number of virtual cpus to allocate."), + memory: int = typer.Option(settings.DEFAULT_VM_MEMORY, help=help_strings.MEMORY), + vcpus: int = typer.Option(settings.DEFAULT_VM_VCPUS, help=help_strings.VCPUS), timeout_seconds: float = typer.Option( settings.DEFAULT_VM_TIMEOUT, - help="If vm is not called after [timeout_seconds] it will shutdown", + help=help_strings.TIMEOUT_SECONDS, ), - private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), - private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), - print_messages: bool = typer.Option(False), - print_code_message: bool = typer.Option(False), - print_program_message: bool = typer.Option(False), + name: Optional[str] = typer.Option(None, help="Name for your program"), runtime: str = typer.Option( None, - help="Hash of the runtime to use for your program. Defaults to aleph debian with Python3.8 and node. You can also create your own runtime and pin it", + help=help_strings.PROGRAM_RUNTIME.format(runtime_id=settings.DEFAULT_RUNTIME_ID), ), beta: bool = typer.Option( False, - help="If true, you will be prompted to add message subscriptions to your program", + help=help_strings.PROGRAM_BETA, ), - debug: bool = False, persistent: bool = False, + updatable: bool = typer.Option(False, help=help_strings.PROGRAM_UPDATABLE), + skip_volume: bool = typer.Option(False, help=help_strings.SKIP_VOLUME), persistent_volume: Optional[List[str]] = typer.Option(None, help=help_strings.PERSISTENT_VOLUME), ephemeral_volume: Optional[List[str]] = typer.Option(None, help=help_strings.EPHEMERAL_VOLUME), immutable_volume: Optional[List[str]] = typer.Option( None, help=help_strings.IMMUTABLE_VOLUME, ), -): - """Register a program to run on aleph.im. For more information, see https://docs.aleph.im/computing/""" + skip_env_var: bool = typer.Option(False, help=help_strings.SKIP_ENV_VAR), + env_vars: Optional[str] = typer.Option(None, help=help_strings.ENVIRONMENT_VARIABLES), + private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), + private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), + print_messages: bool = typer.Option(False), + print_code_message: bool = typer.Option(False), + print_program_message: bool = typer.Option(False), + verbose: bool = True, + debug: bool = False, +) -> Optional[str]: + """Register a program to run on aleph.im (create/upload are aliases) + + For more information, see https://docs.aleph.im/computing""" setup_logging(debug) @@ -79,23 +103,30 @@ async def upload( path_object, encoding = create_archive(path) except BadZipFile: typer.echo("Invalid zip archive") - raise typer.Exit(3) + raise typer.Exit(code=3) except FileNotFoundError: typer.echo("No such file or directory") - raise typer.Exit(4) + raise typer.Exit(code=4) account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - runtime = runtime or input(f"Ref of runtime ? [{settings.DEFAULT_RUNTIME_ID}] ") or settings.DEFAULT_RUNTIME_ID + name = name or validated_prompt("Program name", lambda x: len(x) < 65) + runtime = runtime or input(f"Ref of runtime? [{settings.DEFAULT_RUNTIME_ID}] ") or settings.DEFAULT_RUNTIME_ID - volumes = get_or_prompt_volumes( - persistent_volume=persistent_volume, - ephemeral_volume=ephemeral_volume, - immutable_volume=immutable_volume, - ) + volumes = [] + if not skip_volume: + volumes = get_or_prompt_volumes( + persistent_volume=persistent_volume, + ephemeral_volume=ephemeral_volume, + immutable_volume=immutable_volume, + ) + + environment_variables = None + if not skip_env_var: + environment_variables = get_or_prompt_environment_variables(env_vars) subscriptions: Optional[List[Mapping]] = None - if beta and yes_no_input("Subscribe to messages ?", default=False): + if beta and yes_no_input("Subscribe to messages?", default=False): content_raw = input_multiline() try: subscriptions = json.loads(content_raw) @@ -131,6 +162,8 @@ async def upload( message, status = await client.create_program( program_ref=program_ref, entrypoint=entrypoint, + metadata=dict(name=name), + allow_amend=updatable, runtime=runtime, storage_engine=StorageEnum.storage, channel=channel, @@ -140,6 +173,7 @@ async def upload( persistent=persistent, encoding=encoding, volumes=volumes, + environment_variables=environment_variables, subscriptions=subscriptions, ) logger.debug("Upload finished") @@ -147,57 +181,100 @@ async def upload( typer.echo(f"{message.json(indent=4)}") item_hash: ItemHash = message.item_hash - hash_base32 = b32encode(b16decode(item_hash.upper())).strip(b"=").lower().decode() - - typer.echo( - f"Your program has been uploaded on aleph.im\n\n" - "Available on:\n" - f" {settings.VM_URL_PATH.format(hash=item_hash)}\n" - f" {settings.VM_URL_HOST.format(hash_base32=hash_base32)}\n" - "Visualise on:\n https://explorer.aleph.im/address/" - f"{message.chain.value}/{message.sender}/message/PROGRAM/{item_hash}\n" - ) + if verbose: + hash_base32 = b32encode(b16decode(item_hash.upper())).strip(b"=").lower().decode() + func_url_1 = f"{settings.VM_URL_PATH.format(hash=item_hash)}" + func_url_2 = f"{settings.VM_URL_HOST.format(hash_base32=hash_base32)}" + + console = Console() + infos = [ + Text.from_markup(f"Your program [bright_cyan]{item_hash}[/bright_cyan] has been uploaded on aleph.im."), + Text.assemble( + "\n\nAvailable on:\n", + Text.from_markup( + f"↳ [bright_yellow][link={func_url_1}]{func_url_1}[/link][/bright_yellow]\n", + style="italic", + ), + Text.from_markup( + f"↳ [dark_olive_green2][link={func_url_2}]{func_url_2}[/link][/dark_olive_green2]", + style="italic", + ), + "\n\nVisualise on:\n", + Text.from_markup( + f"[blue]https://explorer.aleph.im/address/{message.chain.value}/{message.sender}/message/PROGRAM/{item_hash}[/blue]" + ), + ), + ] + console.print( + Panel( + Text.assemble(*infos), + title="Program Created", + border_style="green", + expand=False, + title_align="left", + ) + ) + return item_hash @app.command() async def update( item_hash: str = typer.Argument(..., help="Item hash to update"), - path: Path = typer.Argument(..., help="Source path to upload"), + path: Path = typer.Argument(..., help=help_strings.PROGRAM_PATH), private_key: Optional[str] = settings.PRIVATE_KEY_STRING, private_key_file: Optional[Path] = settings.PRIVATE_KEY_FILE, - print_message: bool = True, + print_message: bool = typer.Option(False), + verbose: bool = True, debug: bool = False, ): - """Update the code of an existing program""" + """Update the code of an existing program (item hash will not change)""" setup_logging(debug) - account = _load_account(private_key, private_key_file) path = path.absolute() - async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: - program_message: ProgramMessage = await client.get_message(item_hash=item_hash, message_type=ProgramMessage) - code_ref = program_message.content.code.ref - code_message: StoreMessage = await client.get_message(item_hash=code_ref, message_type=StoreMessage) + try: + path_object, encoding = create_archive(path) + except BadZipFile: + typer.echo("Invalid zip archive") + raise typer.Exit(code=3) + except FileNotFoundError: + typer.echo("No such file or directory") + raise typer.Exit(code=4) + + account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: try: - path, encoding = create_archive(path) - except BadZipFile: - typer.echo("Invalid zip archive") - raise typer.Exit(3) - except FileNotFoundError: - typer.echo("No such file or directory") - raise typer.Exit(4) + program_message: ProgramMessage = await client.get_message(item_hash=item_hash, message_type=ProgramMessage) + except MessageNotFoundError: + typer.echo("Program does not exist on aleph.im") + return 1 + except ForgottenMessageError: + typer.echo("Program has been deleted on aleph.im") + return 1 + if program_message.sender != account.get_address(): + typer.echo("You are not the owner of this program") + return 1 + code_ref = program_message.content.code.ref + try: + code_message: StoreMessage = await client.get_message(item_hash=code_ref, message_type=StoreMessage) + except MessageNotFoundError: + typer.echo("Code volume does not exist on aleph.im") + return 1 + except ForgottenMessageError: + typer.echo("Code volume has been deleted on aleph.im") + return 1 if encoding != program_message.content.code.encoding: logger.error( f"Code must be encoded with the same encoding as the previous version " f"('{encoding}' vs '{program_message.content.code.encoding}'" ) - raise typer.Exit(1) + return 1 - # Upload the source code - with open(path, "rb") as fd: + # Upload the new source code + with open(path_object, "rb") as fd: logger.debug("Reading file") # TODO: Read in lazy mode instead of copying everything in memory file_content = fd.read() @@ -214,24 +291,345 @@ async def update( if print_message: typer.echo(f"{message.json(indent=4)}") + if verbose: + hash_base32 = b32encode(b16decode(item_hash.upper())).strip(b"=").lower().decode() + func_url_1 = f"{settings.VM_URL_PATH.format(hash=item_hash)}" + func_url_2 = f"{settings.VM_URL_HOST.format(hash_base32=hash_base32)}" + console = Console() + infos = [ + Text.from_markup( + f"Your program [bright_cyan]{item_hash}[/bright_cyan] has been updated to the new source code." + ), + Text.from_markup(f"\n\nUpdated code volume: [orange3]{code_message.item_hash}[/orange3]"), + Text.assemble( + "\n\nAvailable on:\n", + Text.from_markup( + f"↳ [bright_yellow][link={func_url_1}]{func_url_1}[/link][/bright_yellow]\n", + style="italic", + ), + Text.from_markup( + f"↳ [dark_olive_green2][link={func_url_2}]{func_url_2}[/link][/dark_olive_green2]", + style="italic", + ), + ), + ] + console.print( + Panel( + Text.assemble(*infos), + title="Program Updated", + border_style="orange3", + expand=False, + title_align="left", + ) + ) + @app.command() -async def unpersist( +async def delete( item_hash: str = typer.Argument(..., help="Item hash to unpersist"), + reason: str = typer.Option("User deletion", help="Reason for deleting the program"), + keep_code: bool = typer.Option(False, help=help_strings.PROGRAM_KEEP_CODE), private_key: Optional[str] = settings.PRIVATE_KEY_STRING, private_key_file: Optional[Path] = settings.PRIVATE_KEY_FILE, + print_message: bool = typer.Option(False), + verbose: bool = True, debug: bool = False, ): - """Stop a persistent virtual machine by making it non-persistent""" + """Delete a program""" + + setup_logging(debug) + + account = _load_account(private_key, private_key_file) + + async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: + try: + existing_message: ProgramMessage = await client.get_message( + item_hash=item_hash, message_type=ProgramMessage + ) + except MessageNotFoundError: + typer.echo("Program does not exist on aleph.im") + return 1 + except ForgottenMessageError: + typer.echo("Program has been already deleted on aleph.im") + return 1 + if existing_message.sender != account.get_address(): + typer.echo("You are not the owner of this program") + return 1 + + message, _ = await client.forget(hashes=[ItemHash(item_hash)], reason=reason) + if not keep_code: + try: + code_volume: StoreMessage = await client.get_message( + item_hash=existing_message.content.code.ref, message_type=StoreMessage + ) + except MessageNotFoundError: + typer.echo("Code volume does not exist. Skipping...") + return 1 + except ForgottenMessageError: + typer.echo("Code volume has been already deleted. Skipping...") + return 1 + if existing_message.sender != account.get_address(): + typer.echo("You are not the owner of this code volume. Skipping...") + return 1 + + code_message, _ = await client.forget( + hashes=[ItemHash(code_volume.item_hash)], reason=f"Deletion of program {item_hash}" + ) + if verbose: + typer.echo(f"Code volume {code_volume.item_hash} has been deleted.") + if print_message: + typer.echo(f"{message.json(indent=4)}") + if verbose: + typer.echo(f"Program {item_hash} has been deleted.") + + +@app.command(name="list") +async def list_programs( + address: Optional[str] = typer.Option(None, help="Owner address of the programs"), + private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), + private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), + json: bool = typer.Option(default=False, help="Print as json instead of rich table"), + debug: bool = False, +): + """List all programs associated to an account""" + + setup_logging(debug) + + if address is None: + account = _load_account(private_key, private_key_file) + address = account.get_address() + + async with AlephHttpClient(api_server=settings.API_HOST) as client: + resp = await client.get_messages( + message_filter=MessageFilter( + message_types=[MessageType.program], + addresses=[address], + ), + page_size=100, + ) + messages = await filter_only_valid_messages(resp.messages) + if not messages: + typer.echo(f"Address: {address}\n\nNo program found\n") + raise typer.Exit(code=1) + if json: + for message in messages: + typer.echo(message.json(indent=4)) + else: + # Since we filtered on message type, we can safely cast as ProgramMessage. + messages = cast(List[ProgramMessage], messages) + + table = Table(box=box.ROUNDED, style="blue_violet") + table.add_column(f"Programs [{len(messages)}]", style="blue", overflow="fold") + table.add_column("Specifications", style="blue") + table.add_column("Configurations", style="blue", overflow="fold") + + for message in messages: + name = Text( + ( + message.content.metadata["name"] + if hasattr(message.content, "metadata") + and isinstance(message.content.metadata, dict) + and "name" in message.content.metadata + else "-" + ), + style="magenta3", + ) + msg_link = f"https://explorer.aleph.im/address/ETH/{message.sender}/message/PROGRAM/{message.item_hash}" + item_hash_link = Text.from_markup(f"[link={msg_link}]{message.item_hash}[/link]", style="bright_cyan") + created_at = Text.assemble( + "URLs ↓\t Created at: ", + Text( + str(str_to_datetime(str(safe_getattr(message, "content.time")))).split(".", maxsplit=1)[0], + style="orchid", + ), + ) + hash_base32 = b32encode(b16decode(message.item_hash.upper())).strip(b"=").lower().decode() + func_url_1 = settings.VM_URL_PATH.format(hash=message.item_hash) + func_url_2 = settings.VM_URL_HOST.format(hash_base32=hash_base32) + urls = Text.from_markup( + f"[bright_yellow][link={func_url_1}]{func_url_1}[/link][/bright_yellow]\n[dark_olive_green2][link={func_url_2}]{func_url_2}[/link][/dark_olive_green2]" + ) + program = Text.assemble( + "Item Hash ↓\t Name: ", name, "\n", item_hash_link, "\n", created_at, "\n", urls + ) + specs = [ + f"vCPU: [magenta3]{message.content.resources.vcpus}[/magenta3]\n", + f"RAM: [magenta3]{message.content.resources.memory / 1_024:.2f} GiB[/magenta3]\n", + "HyperV: [magenta3]Firecracker[/magenta3]\n", + f"Timeout: [orange3]{message.content.resources.seconds}s[/orange3]\n", + f"Persistent: {'[green]Yes[/green]' if message.content.on.persistent else '[red]No[/red]'}\n", + f"Updatable: {'[green]Yes[/green]' if message.content.allow_amend else '[red]No[/red]'}", + ] + specifications = Text.from_markup("".join(specs)) + volumes = "" + for volume in message.content.volumes: + if safe_getattr(volume, "ref"): + volumes += f"\n• [orchid]{volume.mount}[/orchid]: [bright_cyan][link={settings.API_HOST}/api/v0/messages/{volume.ref}]{volume.ref}[/link][/bright_cyan]" + elif safe_getattr(volume, "ephemeral"): + volumes += f"\n• [orchid]{volume.mount}[/orchid]: [bright_red]ephemeral[/bright_red]" + else: + volumes += f"\n• [orchid]{volume.mount}[/orchid]: [orange3]persistent on {volume.persistence.value}[/orange3]" + config = Text.assemble( + Text.from_markup( + f"Runtime: [bright_cyan][link={settings.API_HOST}/api/v0/messages/{message.content.runtime.ref}]{message.content.runtime.ref}[/link][/bright_cyan]\n" + f"Code: [bright_cyan][link={settings.API_HOST}/api/v0/messages/{message.content.code.ref}]{message.content.code.ref}[/link][/bright_cyan]\n" + f"↳ Entrypoint: [orchid]{message.content.code.entrypoint}[/orchid]\n" + ), + Text.from_markup(f"Mounted Volumes: {volumes if volumes else '-'}"), + ) + table.add_row(program, specifications, config) + table.add_section() + + console = Console() + console.print(table) + infos = [ + Text.from_markup( + f"[bold]Address:[/bold] [bright_cyan]{messages[0].content.address}[/bright_cyan]\n\nTo access any program's logs, use:\n" + ), + Text.from_markup( + "↳ aleph program logs [bright_cyan][/bright_cyan] --domain [orchid][/orchid]", + style="italic", + ), + ] + console.print( + Panel( + Text.assemble(*infos), title="Infos", border_style="bright_cyan", expand=False, title_align="left" + ) + ) + + +@app.command() +async def persist( + item_hash: str = typer.Argument(..., help="Item hash to persist"), + keep_prev: bool = typer.Option( + False, + help=help_strings.PROGRAM_KEEP_PREV, + ), + private_key: Optional[str] = settings.PRIVATE_KEY_STRING, + private_key_file: Optional[Path] = settings.PRIVATE_KEY_FILE, + print_message: bool = typer.Option(False), + verbose: bool = True, + debug: bool = False, +) -> Optional[str]: + """Recreate a non-persistent program as persistent (item hash will change)""" setup_logging(debug) account = _load_account(private_key, private_key_file) async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: - message: ProgramMessage = await client.get_message(item_hash=item_hash, message_type=ProgramMessage) + try: + message: ProgramMessage = await client.get_message(item_hash=item_hash, message_type=ProgramMessage) + except MessageNotFoundError: + typer.echo("Program does not exist on aleph.im") + return None + except ForgottenMessageError: + typer.echo("Program has been deleted on aleph.im") + return None + if message.sender != account.get_address(): + typer.echo("You are not the owner of this program") + return None + if not message.content.allow_amend: + typer.echo("Program is not updatable") + return None + if message.content.on.persistent: + typer.echo("Program is already persistent") + return None + + # Update content content: ProgramContent = message.content.copy() + content.on.persistent = True + content.replaces = message.item_hash + + message, _status, _ = await client.submit( + content=content.dict(exclude_none=True), + message_type=message.type, + channel=message.channel, + ) + + if print_message: + typer.echo(f"{message.json(indent=4)}") + + # Delete previous non-persistent program + prev_label, prev_color = "INTACT", "orange3" + if not keep_prev: + await client.forget(hashes=[ItemHash(item_hash)], reason="Program persisted") + prev_label, prev_color = "DELETED", "red" + + if verbose: + hash_base32 = b32encode(b16decode(item_hash.upper())).strip(b"=").lower().decode() + func_url_1 = f"{settings.VM_URL_PATH.format(hash=item_hash)}" + func_url_2 = f"{settings.VM_URL_HOST.format(hash_base32=hash_base32)}" + console = Console() + infos = [ + Text.from_markup("Your program is now [green]persistent[/green]. It implies a new item hash."), + Text.from_markup( + f"\n\n[{prev_color}]- Prev non-persistent program: {item_hash} -> {prev_label}[/{prev_color}]\n[green]- New persistent program: {message.item_hash}[/green]." + ), + Text.assemble( + "\n\nAvailable on:\n", + Text.from_markup( + f"↳ [bright_yellow][link={func_url_1}]{func_url_1}[/link][/bright_yellow]\n", + style="italic", + ), + Text.from_markup( + f"↳ [dark_olive_green2][link={func_url_2}]{func_url_2}[/link][/dark_olive_green2]", + style="italic", + ), + ), + ] + console.print( + Panel( + Text.assemble(*infos), + title="Program: Persist", + border_style="orchid", + expand=False, + title_align="left", + ) + ) + return message.item_hash + + +@app.command() +async def unpersist( + item_hash: str = typer.Argument(..., help="Item hash to unpersist"), + keep_prev: bool = typer.Option( + False, + help=help_strings.PROGRAM_KEEP_PREV, + ), + private_key: Optional[str] = settings.PRIVATE_KEY_STRING, + private_key_file: Optional[Path] = settings.PRIVATE_KEY_FILE, + print_message: bool = typer.Option(False), + verbose: bool = True, + debug: bool = False, +) -> Optional[str]: + """Recreate a persistent program as non-persistent (item hash will change)""" + + setup_logging(debug) + account = _load_account(private_key, private_key_file) + + async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: + try: + message: ProgramMessage = await client.get_message(item_hash=item_hash, message_type=ProgramMessage) + except MessageNotFoundError: + typer.echo("Program does not exist on aleph.im") + return None + except ForgottenMessageError: + typer.echo("Program has been deleted on aleph.im") + return None + if message.sender != account.get_address(): + typer.echo("You are not the owner of this program") + return None + if not message.content.allow_amend: + typer.echo("Program is not updatable") + return None + if not message.content.on.persistent: + typer.echo("Program is already unpersistent") + return None + + # Update content + content: ProgramContent = message.content.copy() content.on.persistent = False content.replaces = message.item_hash @@ -240,7 +638,48 @@ async def unpersist( message_type=message.type, channel=message.channel, ) - typer.echo(f"{message.json(indent=4)}") + + if print_message: + typer.echo(f"{message.json(indent=4)}") + + # Delete previous persistent program + prev_label, prev_color = "INTACT", "orange3" + if not keep_prev: + await client.forget(hashes=[ItemHash(item_hash)], reason="Program unpersisted") + prev_label, prev_color = "DELETED", "red" + + if verbose: + hash_base32 = b32encode(b16decode(item_hash.upper())).strip(b"=").lower().decode() + func_url_1 = f"{settings.VM_URL_PATH.format(hash=item_hash)}" + func_url_2 = f"{settings.VM_URL_HOST.format(hash_base32=hash_base32)}" + console = Console() + infos = [ + Text.from_markup("Your program is now [red]unpersistent[/red]. It implies a new item hash."), + Text.from_markup( + f"\n\n[{prev_color}]- Prev persistent program: {item_hash} -> {prev_label}[/{prev_color}]\n[green]- New non-persistent program: {message.item_hash}[/green]." + ), + Text.assemble( + "\n\nAvailable on:\n", + Text.from_markup( + f"↳ [bright_yellow][link={func_url_1}]{func_url_1}[/link][/bright_yellow]\n", + style="italic", + ), + Text.from_markup( + f"↳ [dark_olive_green2][link={func_url_2}]{func_url_2}[/link][/dark_olive_green2]", + style="italic", + ), + ), + ] + console.print( + Panel( + Text.assemble(*infos), + title="Program: Unpersist", + border_style="orchid", + expand=False, + title_align="left", + ) + ) + return message.item_hash @app.command() @@ -248,18 +687,18 @@ async def logs( item_hash: str = typer.Argument(..., help="Item hash of program"), private_key: Optional[str] = settings.PRIVATE_KEY_STRING, private_key_file: Optional[Path] = settings.PRIVATE_KEY_FILE, - domain: str = typer.Option(None, help="CRN domain on which the VM is stored or running"), + domain: str = typer.Option(None, help=help_strings.PROMPT_PROGRAM_CRN_URL), chain: Chain = typer.Option(None, help=help_strings.ADDRESS_CHAIN), debug: bool = False, ): - """Display logs for the program. + """Display the logs of a program - Will only show logs from one selected CRN""" + Will only show logs from the selected CRN""" setup_logging(debug) account = _load_account(private_key, private_key_file, chain=chain) - domain = sanitize_url(domain) + domain = sanitize_url(domain or Prompt.ask(help_strings.PROMPT_PROGRAM_CRN_URL)) async with VmClient(account, domain) as client: async with client.operate(vm_id=item_hash, operation="logs", method="GET") as response: @@ -269,17 +708,96 @@ async def logs( logger.debug(await response.text()) if response.status == 404: - echo(f"Server didn't found any execution of this prorgam") + echo(f"Server didn't found any execution of this program") return 1 elif response.status == 403: - echo(f"You are not the owner of this VM. Maybe try with another wallet?") - return 1 elif response.status != 200: - echo(f"Server error: {response.status}. Please try again latter") + echo(f"Server error: {response.status}. Please try again later") return 1 echo("Received logs") log_entries = await response.json() for log in log_entries: echo(f'{log["__REALTIME_TIMESTAMP"]}> {log["MESSAGE"]}') + + +@app.command() +async def runtime_checker( + item_hash: str = typer.Argument(..., help="Item hash of the runtime to check"), + private_key: Optional[str] = settings.PRIVATE_KEY_STRING, + private_key_file: Optional[Path] = settings.PRIVATE_KEY_FILE, + verbose: bool = False, + debug: bool = False, +): + """Check versions used by a runtime (distribution, python, nodejs, etc)""" + + setup_logging(debug) + + echo("Deploy runtime checker program...") + try: + program_hash = await upload( + path=Path(__file__).resolve().parent / "program_utils/runtime_checker.squashfs", + entrypoint="main:app", + channel=settings.DEFAULT_CHANNEL, + memory=settings.DEFAULT_VM_MEMORY, + vcpus=settings.DEFAULT_VM_VCPUS, + timeout_seconds=settings.DEFAULT_VM_TIMEOUT, + name="runtime_checker", + runtime=item_hash, + beta=False, + persistent=False, + updatable=False, + skip_volume=True, + skip_env_var=True, + private_key=private_key, + private_key_file=private_key_file, + print_messages=False, + print_code_message=False, + print_program_message=False, + verbose=verbose, + debug=debug, + ) + if not program_hash: + raise Exception("No program hash") + except Exception as e: + echo(f"Failed to deploy the runtime checker program: {e}") + raise typer.Exit(code=1) + + program_url = settings.VM_URL_PATH.format(hash=program_hash) + versions: dict + echo("Query runtime checker to retrieve versions...") + try: + timeout = aiohttp.ClientTimeout(total=settings.HTTP_REQUEST_TIMEOUT) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.get(program_url) as resp: + resp.raise_for_status() + versions = await resp.json() + except Exception as e: + logger.debug(f"Unexpected error when calling {program_url}: {e}") + raise typer.Exit(code=1) + + echo("Delete runtime checker...") + try: + await delete( + item_hash=program_hash, + reason="Automatic deletion of the runtime checker program", + keep_code=True, + private_key=private_key, + private_key_file=private_key_file, + print_message=False, + verbose=verbose, + debug=debug, + ) + except Exception as e: + echo(f"Failed to delete the runtime checker program: {e}") + raise typer.Exit(code=1) + + console = Console() + infos = [Text.from_markup(f"[bold]Ref:[/bold] [bright_cyan]{item_hash}[/bright_cyan]")] + for label, version in versions.items(): + color = "green" if bool(re.search(r"\d", version)) else "red" + infos.append(Text.from_markup(f"\n[bold]{label}:[/bold] [{color}]{version}[/{color}]")) + console.print( + Panel(Text.assemble(*infos), title="Runtime Infos", border_style="violet", expand=False, title_align="left") + ) diff --git a/src/aleph_client/commands/program_utils/runtime_checker.squashfs b/src/aleph_client/commands/program_utils/runtime_checker.squashfs new file mode 100644 index 00000000..59be5e12 Binary files /dev/null and b/src/aleph_client/commands/program_utils/runtime_checker.squashfs differ diff --git a/src/aleph_client/commands/program_utils/runtime_checker/main.py b/src/aleph_client/commands/program_utils/runtime_checker/main.py new file mode 100644 index 00000000..def71810 --- /dev/null +++ b/src/aleph_client/commands/program_utils/runtime_checker/main.py @@ -0,0 +1,37 @@ +import platform +import subprocess +from typing import Dict + +from fastapi import FastAPI + +app = FastAPI() + +extra_checks = dict( + Docker="docker --version", + Nodejs="node --version", + Rust="rustc --version", + Go="go version", +) + + +@app.get("/") +async def versions() -> Dict[str, str]: + results = dict() + + # Distribution + try: + results["Distribution"] = platform.freedesktop_os_release()["PRETTY_NAME"] # type: ignore + except Exception: + results["Distribution"] = "Not available" + + # Python + results["Python"] = platform.python_version() + + # Others + for label, command in extra_checks.items(): + try: + results[label] = subprocess.check_output(command.split(" ")).decode("utf-8").strip() + except Exception: + results[label] = "Not installed" + + return results diff --git a/src/aleph_client/commands/utils.py b/src/aleph_client/commands/utils.py index 6e4eeb75..e015cb70 100644 --- a/src/aleph_client/commands/utils.py +++ b/src/aleph_client/commands/utils.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import shutil import sys from datetime import datetime from pathlib import Path @@ -20,7 +21,7 @@ from pygments.formatters.terminal256 import Terminal256Formatter from pygments.lexers import JsonLexer from rich.prompt import IntPrompt, Prompt, PromptError -from typer import colors, echo, style +from typer import Exit, colors, echo, style from aleph_client.utils import fetch_json @@ -79,7 +80,7 @@ def yes_no_input(text: str, default: str | bool) -> bool: def prompt_for_volumes(): - while yes_no_input("Add volume ?", default=False): + while yes_no_input("Add volume?", default=False): mount = validated_prompt("Mount path (ex: /opt/data): ", lambda text: len(text) > 0) name = validated_prompt("Name: ", lambda text: len(text) > 0) comment = Prompt.ask("Comment: ") @@ -95,7 +96,7 @@ def prompt_for_volumes(): } else: ref = validated_prompt("Item hash: ", lambda text: len(text) == 64) - use_latest = yes_no_input("Use latest version ?", default=True) + use_latest = yes_no_input("Use latest version?", default=True) yield { "comment": comment, "mount": mount, @@ -150,6 +151,27 @@ def get_or_prompt_volumes(ephemeral_volume, immutable_volume, persistent_volume) return volumes +def env_vars_to_dict(env_vars: Optional[str]) -> Dict[str, str]: + dict_store: Dict[str, str] = {} + if env_vars: + for env_var in env_vars.split(","): + label, value = env_var.split("=", 1) + dict_store[label.strip()] = value.strip() + return dict_store + + +def get_or_prompt_environment_variables(env_vars: Optional[str]) -> Optional[Dict[str, str]]: + environment_variables: Dict[str, str] = {} + if not env_vars: + while yes_no_input("Add environment variable?", default=False): + label = validated_prompt("Label: ", lambda text: len(text) > 0) + value = validated_prompt("Value: ", lambda text: len(text) > 0) + environment_variables[label] = value + else: + environment_variables = env_vars_to_dict(env_vars) + return environment_variables if environment_variables else None + + def str_to_datetime(date: Optional[str]) -> Optional[datetime]: """ Converts a string representation of a date/time to a datetime object. @@ -239,14 +261,6 @@ def is_environment_interactive() -> bool: ) -def safe_getattr(obj, attr, default=None): - for part in attr.split("."): - obj = getattr(obj, part, default) - if obj is default: - break - return obj - - async def wait_for_processed_instance(session: ClientSession, item_hash: ItemHash): """Wait for a message to be processed by CCN""" while True: @@ -297,3 +311,13 @@ def validate_ssh_pubkey_file(file: Union[str, Path]) -> Path: if not file.is_file(): raise ValueError(f"{file} is not a file") return file + + +def find_sevctl_or_exit() -> Path: + "Find sevctl in path, exit with message if not available" + sevctl_path = shutil.which("sevctl") + if sevctl_path is None: + echo("sevctl binary is not available. Please install sevctl, ensure it is in the PATH and try again.") + echo("Instructions for setup https://docs.aleph.im/computing/confidential/requirements/") + raise Exit(code=1) + return Path(sevctl_path) diff --git a/src/aleph_client/models.py b/src/aleph_client/models.py index 9bf8d3f9..b8bbe893 100644 --- a/src/aleph_client/models.py +++ b/src/aleph_client/models.py @@ -1,10 +1,11 @@ from datetime import datetime -from enum import Enum from typing import List, Optional from aleph_message.models import ItemHash -from aleph_message.models.execution.environment import CpuProperties +from aleph_message.models.execution.environment import CpuProperties, GpuDeviceClass from pydantic import BaseModel +from rich.console import Console +from rich.panel import Panel from typer import echo from aleph_client.commands.node import _escape_and_normalize, _remove_ansi_escape @@ -49,7 +50,7 @@ class MachineProperties(BaseModel): class GpuDevice(BaseModel): vendor: str device_name: str - device_class: str + device_class: GpuDeviceClass pci_host: str device_id: str @@ -150,16 +151,28 @@ def display_hdd(self) -> str: return "" def display_crn_specs(self): - echo(f"Hash: {self.hash}") - echo(f"Name: {self.name}") - echo(f"URL: {self.url}") - echo(f"Version: {self.version}") - echo(f"Score: {self.score}") - echo(f"Stream receiver: {self.stream_reward_address}") - if isinstance(self.machine_usage, MachineUsage): - echo(f"Available Cores: {self.display_cpu}") - echo(f"Available RAM: {self.display_ram}") - echo(f"Available Disk: {self.display_hdd}") - echo(f"Support Qemu: {self.qemu_support}") - echo(f"Support Confidential: {self.confidential_computing}") - echo(f"Support GPU: {self.gpu_support}") + console = Console() + + data = { + "Hash": self.hash, + "Name": self.name, + "URL": self.url, + "Version": self.version, + "Score": self.score, + "Stream Receiver": self.stream_reward_address, + **( + { + "Available Cores": self.display_cpu, + "Available RAM": self.display_ram, + "Available Disk": self.display_hdd, + } + if isinstance(self.machine_usage, MachineUsage) + else {} + ), + "Support Qemu": self.qemu_support, + "Support Confidential": self.confidential_computing, + "Support GPU": self.gpu_support, + } + text = "\n".join(f"[orange3]{key}[/orange3]: {value}" for key, value in data.items()) + + console.print(Panel(text, title="Selected CRN", border_style="bright_cyan", expand=False, title_align="left")) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3de3007a..b2bb960c 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -8,7 +8,7 @@ """ from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Generator +from typing import Generator, Tuple import pytest from aleph.sdk.chains.common import generate_key @@ -27,7 +27,7 @@ def empty_account_file() -> Generator[Path, None, None]: @pytest.fixture -def env_files(new_config_file: Path, empty_account_file: Path) -> Generator[Path, None, None]: +def env_files(new_config_file: Path, empty_account_file: Path) -> Generator[Tuple[Path, Path], None, None]: new_config_file.write_text(f'{{"path": "{empty_account_file}", "chain": "ETH"}}') empty_account_file.write_bytes(generate_key()) yield empty_account_file, new_config_file diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py new file mode 100644 index 00000000..7d132dc6 --- /dev/null +++ b/tests/unit/mocks.py @@ -0,0 +1,45 @@ +from decimal import Decimal +from unittest.mock import AsyncMock, MagicMock + +from aleph.sdk.chains.evm import EVMAccount +from aleph.sdk.conf import settings +from eth_utils.currency import to_wei +from pydantic import BaseModel + +# Change to Aleph testnet +settings.API_HOST = "https://api.twentysix.testnet.network" + +# Utils +FAKE_PRIVATE_KEY = b"cafe" * 8 +FAKE_PUBKEY_FILE = "/path/fake/pubkey" +FAKE_ADDRESS_EVM = "0x00001A0e6B9a46Be48a294D74D897d9C48678862" +FAKE_STORE_HASH = "102682ea8bcc0cec9c42f32fbd2660286b4eb31003108440988343726304607a" # Has to exist on Aleph Testnet +FAKE_STORE_HASH_CONTENT_FILE_CID = "QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW" # From FAKE_STORE_HASH message +FAKE_VM_HASH = "ab12" * 16 +FAKE_PROGRAM_HASH = "cd34" * 16 +FAKE_PROGRAM_HASH_2 = "ef56" * 16 +FAKE_CRN_HASH = "2cdb78cf561c6f0f839edb817395d3b5ece20d89125c5afba658f9170d6932c8" +FAKE_CRN_URL = "https://dchq.staging.aleph.sh" +FAKE_FLOW_HASH = "0xfake_flow_hash" + + +class Dict(BaseModel): + class Config: + extra = "allow" + + +def create_test_account() -> EVMAccount: + return EVMAccount(private_key=FAKE_PRIVATE_KEY) + + +def create_mock_load_account(): + mock_account = create_test_account() + mock_loader = MagicMock(return_value=mock_account) + mock_loader.return_value.get_super_token_balance = MagicMock(return_value=Decimal(10000 * (10**18))) + mock_loader.return_value.can_transact = MagicMock(return_value=True) + mock_loader.return_value.superfluid_connector = MagicMock(can_start_flow=MagicMock(return_value=True)) + mock_loader.return_value.get_flow = AsyncMock(return_value={"flowRate": to_wei(0.0001, unit="ether")}) + mock_loader.return_value.create_flow = AsyncMock(return_value=FAKE_FLOW_HASH) + mock_loader.return_value.update_flow = AsyncMock(return_value=FAKE_FLOW_HASH) + mock_loader.return_value.delete_flow = AsyncMock(return_value=FAKE_FLOW_HASH) + return mock_loader diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index f04ff4da..369d338a 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -1,8 +1,7 @@ -import contextlib import json +import os from pathlib import Path from tempfile import NamedTemporaryFile -from unittest.mock import AsyncMock, patch from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.conf import settings @@ -10,8 +9,9 @@ from aleph_client.__main__ import app +from .mocks import FAKE_STORE_HASH, FAKE_STORE_HASH_CONTENT_FILE_CID + runner = CliRunner() -settings.API_HOST = "https://api.twentysix.testnet.network" def get_account(my_account_file: Path) -> ETHAccount: @@ -139,9 +139,7 @@ def test_account_balance(env_files): app, ["account", "balance", "--address", "0xCAfEcAfeCAfECaFeCaFecaFecaFECafECafeCaFe", "--chain", "ETH"] ) assert result.exit_code == 0 - assert result.stdout.startswith( - "Failed to retrieve balance for address 0xCAfEcAfeCAfECaFeCaFecaFecaFECafECafeCaFe. Status code: 404" - ) + assert result.stdout.startswith("╭─ Account Infos") def test_account_config(env_files): @@ -285,48 +283,32 @@ def test_file_upload(): def test_file_download(): - # Test download a file to aleph network + # Test download a file from aleph network + ipfs_cid = "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH" result = runner.invoke( app, [ "file", "download", - "QmeomffUNfmQy76CQGy9NdmqEnnHU9soCexBnGU3ezPHVH", + ipfs_cid, ], # 5 bytes file ) assert result.exit_code == 0 assert result.stdout is not None + os.remove(ipfs_cid) -def test_app(): - @contextlib.asynccontextmanager - async def m(self, vm_id, operation, method="GET"): - try: - yield AsyncMock( - url="http://", - status=200, - json=AsyncMock( - return_value=[ - { - "__REALTIME_TIMESTAMP": "2024-02-02 23:34:21", - "MESSAGE": "hello world", - } - ] - ), - ) - finally: - pass - - with patch("aleph_client.commands.program.VmClient.operate", m): - result = runner.invoke( - app, - [ - "program", - "logs", - "--domain", - "http://localhost:4200", - "decadecadecadecadecadecadecadecadecadecadecadecadecadecadecadeca", - ], - ) - assert result.exit_code == 0, result.stdout - assert result.stdout == "Received logs\n2024-02-02 23:34:21> hello world\n" +def test_file_download_only_info(): + # Test retrieve the underlying content cid + result = runner.invoke( + app, + [ + "file", + "download", + FAKE_STORE_HASH, + "--only-info", + ], + standalone_mode=False, + ) + assert result.exit_code == 0 + assert result.return_value.dict()["hash"] == FAKE_STORE_HASH_CONTENT_FILE_CID diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index c96b753e..ff60714b 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -1,23 +1,45 @@ from __future__ import annotations +import asyncio +import random from datetime import datetime, timezone +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from aiohttp import InvalidURL -from aleph.sdk.chains.ethereum import ETHAccount -from aleph_message.models import Chain +from aleph.sdk.conf import settings +from aleph_message.models import Chain, ItemHash from aleph_message.models.execution.base import Payment, PaymentType -from aleph_message.models.execution.environment import CpuProperties -from eth_utils.currency import to_wei +from aleph_message.models.execution.environment import ( + CpuProperties, + GpuDeviceClass, + HypervisorType, + MachineResources, +) from multidict import CIMultiDict, CIMultiDictProxy -from aleph_client.commands.instance import delete +from aleph_client.commands import help_strings +from aleph_client.commands.instance import ( + allocate, + confidential_create, + confidential_init_session, + confidential_start, + create, + delete, + list_instances, + logs, + reboot, + stop, +) from aleph_client.commands.instance.network import fetch_crn_info from aleph_client.models import ( CoreFrequencies, CpuUsage, + CRNInfo, DiskUsage, + GpuDevice, + GPUProperties, LoadAverage, MachineInfo, MachineProperties, @@ -27,11 +49,39 @@ ) from aleph_client.utils import FORBIDDEN_HOSTS, sanitize_url +from .mocks import ( + FAKE_ADDRESS_EVM, + FAKE_CRN_HASH, + FAKE_CRN_URL, + FAKE_PUBKEY_FILE, + FAKE_STORE_HASH, + FAKE_VM_HASH, + Dict, + create_mock_load_account, +) + + +def dummy_gpu_device() -> GpuDevice: + return GpuDevice( + vendor="NVIDIA", + device_name="RTX 4090", + device_class=GpuDeviceClass.VGA_COMPATIBLE_CONTROLLER, + pci_host="01:00.0", + device_id="abcd:1234", + ) + def dummy_machine_info() -> MachineInfo: """Create a dummy MachineInfo object for testing purposes.""" + + gpu_devices = [dummy_gpu_device()] return MachineInfo( - hash="blalba", + hash=FAKE_CRN_HASH, + name="Mock CRN", + url="https://example.com", + version="v420.69", + score=0.5, + reward_address=FAKE_ADDRESS_EVM, machine_usage=MachineUsage( cpu=CpuUsage( count=8, @@ -39,12 +89,12 @@ def dummy_machine_info() -> MachineInfo: core_frequencies=CoreFrequencies(min=1.0, max=2.0), ), mem=MemoryUsage( - total_kB=1_000_000, - available_kB=500_000, + total_kB=32_000_000, + available_kB=28_000_000, ), disk=DiskUsage( - total_kB=1_000_000, - available_kB=500_000, + total_kB=1_000_000_000, + available_kB=500_000_000, ), period=UsagePeriod( start_timestamp=datetime.now(tz=timezone.utc), @@ -56,13 +106,29 @@ def dummy_machine_info() -> MachineInfo: vendor="AuthenticAMD", ), ), - gpu=None, + gpu=GPUProperties( + devices=gpu_devices, + available_devices=gpu_devices, + ), ), - score=0.5, - name="CRN", - version="0.0.1", - reward_address="0xcafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", - url="https://example.com", + ) + + +def create_mock_crn_info(): + mock_machine_info = dummy_machine_info() + return MagicMock( + return_value=CRNInfo( + hash=ItemHash(FAKE_CRN_HASH), + name="Mock CRN", + url=FAKE_CRN_URL, + version="v420.69", + score=0.5, + stream_reward_address=mock_machine_info.reward_address, + machine_usage=mock_machine_info.machine_usage, + qemu_support=True, + confidential_computing=True, + gpu_support=True, + ) ) @@ -113,51 +179,617 @@ def test_sanitize_url_with_https_scheme(): assert sanitize_url(url) == url -class MockETHAccount(ETHAccount): - pass +def create_mock_instance_message(mock_account, payg=False, coco=False, gpu=False): + tmp = list(FAKE_VM_HASH) + random.shuffle(tmp) + vm_item_hash = "".join(tmp) + vm = Dict( + chain=Chain.ETH, + sender=mock_account.get_address(), + type="instance", + channel="ALEPH-CLOUDSOLUTIONS", + confirmed=True, + item_type="inline", + item_hash=vm_item_hash, + content=Dict( + address=mock_account.get_address(), + time=1734037086.2333803, + metadata=dict(name="mock_instance"), + authorized_keys=["ssh-rsa ..."], + environment=Dict(hypervisor=HypervisorType.qemu, trusted_execution=None), + resources=Dict(vcpus=1, memory=2048), + payment=Payment(chain=Chain.ETH, receiver=None, type=PaymentType.hold), + requirements=None, + rootfs=Dict( + parent=Dict(ref=FAKE_STORE_HASH), + size_mib=20480, + ), + volumes=[], + ), + ) + if payg or coco or gpu: + vm.content.metadata["name"] += "_payg" # type: ignore + vm.content.payment = Payment(chain=Chain.AVAX, receiver=FAKE_ADDRESS_EVM, type=PaymentType.superfluid) # type: ignore + vm.content.requirements = Dict( # type: ignore + node=Dict( + node_hash=FAKE_CRN_HASH, + terms_and_conditions=None, + ), + gpu=None, + ) + if coco: + vm.content.metadata["name"] += "_coco" # type: ignore + vm.content.environment.trusted_execution = Dict(firmware=FAKE_STORE_HASH) # type: ignore + if gpu: + vm.content.metadata["name"] += "_gpu" # type: ignore + vm.content.requirements.gpu = [ # type: ignore + Dict( + vendor="NVIDIA", + device_name="RTX 4090", + device_class=GpuDeviceClass.VGA_COMPATIBLE_CONTROLLER, + device_id="abcd:1234", + ) + ] + return vm + + +def create_mock_instance_messages(mock_account): + regular = create_mock_instance_message(mock_account) + payg = create_mock_instance_message(mock_account, payg=True) + coco = create_mock_instance_message(mock_account, coco=True) + gpu = create_mock_instance_message(mock_account, gpu=True) + return AsyncMock(return_value=[regular, payg, coco, gpu]) + +def create_mock_validate_ssh_pubkey_file(): + return MagicMock( + return_value=MagicMock(return_value=FAKE_PUBKEY_FILE, read_text=MagicMock(return_value="ssh-rsa ...")) + ) -def create_test_account() -> MockETHAccount: - return MockETHAccount(private_key=b"deca" * 8) + +def create_mock_fetch_vm_info(): + return AsyncMock( + return_value=[FAKE_VM_HASH, dict(crn_url=FAKE_CRN_URL, allocation_type=help_strings.ALLOCATION_MANUAL)] + ) + + +def create_mock_shutil(): + return MagicMock(which=MagicMock(return_value="/root/.cargo/bin/sevctl", move=MagicMock(return_value="/fake/path"))) + + +def create_mock_client(): + mock_client = AsyncMock(get_message=AsyncMock(return_value=True)) + mock_client_class = MagicMock() + mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) + return mock_client_class, mock_client + + +def create_mock_auth_client(mock_account): + mock_response_get_message = create_mock_instance_message(mock_account, payg=True) + mock_response_create_instance = MagicMock(item_hash=FAKE_VM_HASH) + mock_auth_client = AsyncMock( + get_messages=AsyncMock(), + get_message=AsyncMock(return_value=mock_response_get_message), + create_instance=AsyncMock(return_value=[mock_response_create_instance, 200]), + get_program_price=AsyncMock(return_value=MagicMock(required_tokens=0.0001)), + forget=AsyncMock(return_value=(MagicMock(), 200)), + ) + mock_auth_client_class = MagicMock() + mock_auth_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_auth_client) + return mock_auth_client_class, mock_auth_client + + +def create_mock_vm_client(): + class MockAsyncIteratorLogs: + def __init__(self, *args, **kwargs): + self.items = ['{"message": "Log message 1"}', '{"message": "Log message 2"}'] + + def __aiter__(self): + return self + + async def __anext__(self): + if not self.items: + raise StopAsyncIteration + return self.items.pop(0) + + mock_vm_client = AsyncMock( + start_instance=AsyncMock(return_value=[200, MagicMock()]), + erase_instance=AsyncMock(return_value=[200, MagicMock()]), + reboot_instance=AsyncMock(return_value=[200, MagicMock()]), + stop_instance=AsyncMock(return_value=[200, MagicMock()]), + get_logs=MagicMock(return_value=MockAsyncIteratorLogs()), + ) + mock_vm_client_class = MagicMock() + mock_vm_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_vm_client) + return mock_vm_client_class, mock_vm_client + + +def create_mock_vm_coco_client(): + mock_vm_coco_client = MagicMock( + get_certificates=AsyncMock(return_value=[200, MagicMock()]), + create_session=AsyncMock(), + initialize=AsyncMock(), + close=AsyncMock(), + measurement=AsyncMock(return_value="sev_data"), + validate_measure=AsyncMock(return_value=True), + build_secret=AsyncMock(return_value=["encoded_packet_header", "encoded_secret"]), + inject_secret=AsyncMock(), + ) + mock_vm_coco_client_class = MagicMock(return_value=mock_vm_coco_client) + return mock_vm_coco_client_class, mock_vm_coco_client + + +@pytest.mark.parametrize( + ids=[ + "regular_hold_evm", + "regular_superfluid_evm", + "regular_hold_sol", + "coco_hold_sol", + "coco_hold_evm", + "coco_superfluid_evm", + "gpu_superfluid_evm", + ], + argnames="args, expected", + argvalues=[ + ( # regular_hold_evm + dict( + payment_type="hold", + payment_chain="ETH", + rootfs="debian12", + ), + (FAKE_VM_HASH, None, "ETH"), + ), + ( # regular_superfluid_evm + dict( + payment_type="superfluid", + payment_chain="AVAX", + rootfs="debian12", + crn_hash=FAKE_CRN_HASH, + crn_url=FAKE_CRN_URL, + ), + (FAKE_VM_HASH, FAKE_CRN_URL, "AVAX"), + ), + ( # regular_hold_sol + dict( + payment_type="hold", + payment_chain="SOL", + rootfs="debian12", + ), + (FAKE_VM_HASH, None, "SOL"), + ), + ( # coco_hold_sol + dict( + payment_type="hold", + payment_chain="SOL", + rootfs=FAKE_STORE_HASH, + crn_hash=FAKE_CRN_HASH, + crn_url=FAKE_CRN_URL, + confidential=True, + confidential_firmware=FAKE_STORE_HASH, + ), + (FAKE_VM_HASH, FAKE_CRN_URL, "SOL"), + ), + ( # coco_hold_evm + dict( + payment_type="hold", + payment_chain="ETH", + rootfs=FAKE_STORE_HASH, + crn_hash=FAKE_CRN_HASH, + crn_url=FAKE_CRN_URL, + confidential=True, + confidential_firmware=FAKE_STORE_HASH, + ), + (FAKE_VM_HASH, FAKE_CRN_URL, "ETH"), + ), + ( # coco_superfluid_evm + dict( + payment_type="superfluid", + payment_chain="BASE", + rootfs=FAKE_STORE_HASH, + crn_hash=FAKE_CRN_HASH, + crn_url=FAKE_CRN_URL, + confidential=True, + confidential_firmware=FAKE_STORE_HASH, + ), + (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"), + ), + ( # gpu_superfluid_evm + dict( + payment_type="superfluid", + payment_chain="BASE", + rootfs="debian12", + crn_hash=FAKE_CRN_HASH, + crn_url=FAKE_CRN_URL, + gpu=True, + ), + (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"), + ), + ], +) +@pytest.mark.asyncio +async def test_create_instance(args, expected): + mock_validate_ssh_pubkey_file = create_mock_validate_ssh_pubkey_file() + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_client_class, _ = create_mock_client() + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_vm_client_class, mock_vm_client = create_mock_vm_client() + mock_crn_info = create_mock_crn_info() + mock_validated_int_prompt = MagicMock(return_value=1) + mock_wait_for_processed_instance = AsyncMock() + mock_update_flow = AsyncMock(return_value="fake_flow_hash") + mock_wait_for_confirmed_flow = AsyncMock() + + @patch("aleph_client.commands.instance.validate_ssh_pubkey_file", mock_validate_ssh_pubkey_file) + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.AlephHttpClient", mock_client_class) + @patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.CRNInfo", mock_crn_info) + @patch("aleph_client.commands.instance.validated_int_prompt", mock_validated_int_prompt) + @patch("aleph_client.commands.instance.wait_for_processed_instance", mock_wait_for_processed_instance) + @patch("aleph_client.commands.instance.update_flow", mock_update_flow) + @patch("aleph_client.commands.instance.wait_for_confirmed_flow", mock_wait_for_confirmed_flow) + @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) + async def create_instance(instance_spec): + print() # For better display when pytest -v -s + all_args = dict( + ssh_pubkey_file=FAKE_PUBKEY_FILE, + name="mock_instance", + hypervisor=HypervisorType.qemu, + rootfs_size=20480, + vcpus=1, + memory=2048, + timeout_seconds=settings.DEFAULT_VM_TIMEOUT, + skip_volume=True, + persistent_volume=None, + ephemeral_volume=None, + immutable_volume=None, + channel=settings.DEFAULT_CHANNEL, + crn_hash=None, + crn_url=None, + confidential=False, + gpu=False, + private_key=None, + private_key_file=None, + print_message=False, + debug=False, + ) + all_args.update(instance_spec) + return await create(**all_args) + + returned = await create_instance(args) + mock_load_account.assert_called_once() + mock_validate_ssh_pubkey_file.return_value.read_text.assert_called_once() + mock_auth_client.create_instance.assert_called_once() + if args["payment_type"] == "superfluid": + mock_wait_for_processed_instance.assert_called_once() + mock_update_flow.assert_called_once() + mock_wait_for_confirmed_flow.assert_called_once() + mock_vm_client.start_instance.assert_called_once() + assert returned == expected + + +@pytest.mark.asyncio +async def test_list_instances(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_instance_messages = create_mock_instance_messages(mock_account) + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.filter_only_valid_messages", mock_instance_messages) + async def list_instance(): + print() # For better display when pytest -v -s + await list_instances( + address=mock_account.get_address(), + chain=Chain.ETH, + json=False, + debug=False, + ) + mock_instance_messages.assert_called_once() + mock_auth_client.get_messages.assert_called_once() + mock_auth_client.get_program_price.assert_called() + assert mock_auth_client.get_program_price.call_count == 3 + + await list_instance() @pytest.mark.asyncio async def test_delete_instance(): - item_hash = "cafe" * 16 - test_account = create_test_account() - - # Mocking get_flow and delete_flow methods using patch.object - with patch.object(test_account, "get_flow", AsyncMock(return_value={"flowRate": to_wei(123, unit="ether")})): - delete_flow_mock = AsyncMock() - with patch.object(test_account, "delete_flow", delete_flow_mock): - mock_response_message = MagicMock( - sender=test_account.get_address(), - content=MagicMock( - payment=Payment( - chain=Chain.AVAX, - type=PaymentType.superfluid, - receiver=ETHAccount(private_key=b"cafe" * 8).get_address(), - ) - ), - ) + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_vm_client_class, mock_vm_client = create_mock_vm_client() + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.fetch_vm_info", mock_fetch_vm_info) + @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) + async def delete_instance(): + print() # For better display when pytest -v -s + await delete( + FAKE_VM_HASH, + domain=None, + print_message=False, + debug=False, + ) + mock_auth_client.get_message.assert_called_once() + mock_vm_client.erase_instance.assert_called_once() + mock_account.delete_flow.assert_awaited_once() + mock_auth_client.forget.assert_called_once() + + await delete_instance() + + +@pytest.mark.asyncio +async def test_reboot_instance(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_vm_client_class, mock_vm_client = create_mock_vm_client() + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.network.fetch_vm_info", mock_fetch_vm_info) + @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) + async def reboot_instance(): + print() # For better display when pytest -v -s + await reboot( + FAKE_VM_HASH, + domain=None, + chain=Chain.AVAX, + debug=False, + ) + mock_auth_client.get_message.assert_called_once() + mock_vm_client.reboot_instance.assert_called_once() + + await reboot_instance() + + +@pytest.mark.asyncio +async def test_allocate_instance(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_vm_client_class, mock_vm_client = create_mock_vm_client() + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.network.fetch_vm_info", mock_fetch_vm_info) + @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) + async def allocate_instance(): + print() # For better display when pytest -v -s + await allocate( + FAKE_VM_HASH, + domain=None, + chain=Chain.AVAX, + debug=False, + ) + mock_auth_client.get_message.assert_called_once() + mock_vm_client.start_instance.assert_called_once() + + await allocate_instance() - mock_client = AsyncMock( - get_message=AsyncMock(return_value=mock_response_message), - get_program_price=AsyncMock(return_value=MagicMock(required_tokens=123)), - forget=AsyncMock(return_value=(MagicMock(), MagicMock())), - ) - mock_client_class = MagicMock() - mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) +@pytest.mark.asyncio +async def test_logs_instance(capsys): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_vm_client_class, mock_vm_client = create_mock_vm_client() + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.network.fetch_vm_info", mock_fetch_vm_info) + @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) + async def logs_instance(): + print() # For better display when pytest -v -s + await logs( + FAKE_VM_HASH, + domain=None, + chain=Chain.AVAX, + debug=False, + ) + mock_auth_client.get_message.assert_called_once() + mock_vm_client.get_logs.assert_called_once() + + await logs_instance() + captured = capsys.readouterr() + assert captured.out == "\nLog message 1\nLog message 2\n" + - mock_load_account = MagicMock(return_value=test_account) +@pytest.mark.asyncio +async def test_stop_instance(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_vm_client_class, mock_vm_client = create_mock_vm_client() + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.network.fetch_vm_info", mock_fetch_vm_info) + @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) + async def stop_instance(): + print() # For better display when pytest -v -s + await stop( + FAKE_VM_HASH, + domain=None, + chain=Chain.AVAX, + debug=False, + ) + mock_auth_client.get_message.assert_called_once() + mock_vm_client.stop_instance.assert_called_once() + + await stop_instance() - with patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_client_class): - with patch("aleph_client.commands.instance._load_account", mock_load_account): - await delete(item_hash) - # The flow has been deleted since payment uses Superfluid and there is only one flow mocked - delete_flow_mock.assert_awaited_once() +@pytest.mark.asyncio +async def test_confidential_init_session(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_shutil = create_mock_shutil() + mock_vm_coco_client_class, mock_vm_coco_client = create_mock_vm_coco_client() + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.network.fetch_vm_info", mock_fetch_vm_info) + @patch("aleph_client.commands.utils.shutil", mock_shutil) + @patch("aleph_client.commands.instance.shutil", mock_shutil) + @patch.object(Path, "exists", MagicMock(return_value=True)) + @patch("aleph_client.commands.instance.VmConfidentialClient", mock_vm_coco_client_class) + async def coco_init_session(): + print() # For better display when pytest -v -s + await confidential_init_session( + FAKE_VM_HASH, + domain=None, + chain=Chain.AVAX, + policy=0x1, + keep_session=False, + debug=False, + ) + mock_shutil.which.assert_called_once() + mock_auth_client.get_message.assert_called_once() + mock_vm_coco_client.get_certificates.assert_called_once() + mock_shutil.move.assert_called_once() + mock_vm_coco_client.create_session.assert_called_once() + mock_vm_coco_client.initialize.assert_called_once() + mock_vm_coco_client.close.assert_called_once() + + await coco_init_session() + - # The message has been forgotten - mock_client.forget.assert_called_once() +@pytest.mark.asyncio +async def test_confidential_start(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_shutil = create_mock_shutil() + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_vm_coco_client_class, mock_vm_coco_client = create_mock_vm_coco_client() + mock_calculate_firmware_hash = MagicMock(return_value=FAKE_STORE_HASH) + + @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.utils.shutil", mock_shutil) + @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.network.fetch_vm_info", mock_fetch_vm_info) + @patch.object(Path, "exists", MagicMock(return_value=True)) + @patch.object(Path, "mkdir", MagicMock()) + @patch("aleph_client.commands.instance.VmConfidentialClient", mock_vm_coco_client_class) + @patch("aleph_client.commands.instance.calculate_firmware_hash", mock_calculate_firmware_hash) + async def coco_start(): + print() # For better display when pytest -v -s + await confidential_start( + FAKE_VM_HASH, + domain=None, + chain=Chain.AVAX, + firmware_hash=None, + firmware_file="/fake/file", + vm_secret="fake_secret", + debug=False, + ) + mock_auth_client.get_message.assert_called_once() + mock_vm_coco_client.measurement.assert_called_once() + mock_calculate_firmware_hash.assert_called_once() + mock_vm_coco_client.validate_measure.assert_called_once() + mock_vm_coco_client.build_secret.assert_called_once() + mock_vm_coco_client.inject_secret.assert_called_once() + mock_vm_coco_client.close.assert_called_once() + + await coco_start() + + +@pytest.mark.parametrize( + ids=[ + "coco_from_scratch", + "coco_from_hash", + ], + argnames="args", + argvalues=[ + dict( # coco_from_scratch + payment_type="superfluid", + payment_chain="AVAX", + crn_hash=FAKE_CRN_HASH, + crn_url=FAKE_CRN_URL, + vcpus=1, + memory=2048, + rootfs=FAKE_STORE_HASH, + rootfs_size=20480, + ), + dict(vm_id=FAKE_VM_HASH), # coco_from_hash + ], +) +@pytest.mark.asyncio +async def test_confidential_create(args): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_shutil = create_mock_shutil() + mock_create = AsyncMock(return_value=[FAKE_VM_HASH, FAKE_CRN_URL, "AVAX"]) + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_client_class, mock_client = create_mock_client() + mock_fetch_vm_info = create_mock_fetch_vm_info() + mock_allocate = AsyncMock(return_value=None) + mock_confidential_init_session = AsyncMock(return_value=None) + mock_confidential_start = AsyncMock() + + @patch("aleph_client.commands.utils.shutil", mock_shutil) + @patch("aleph_client.commands.instance.create", mock_create) + @patch("aleph_client.commands.instance.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_client_class) + @patch("aleph_client.commands.instance.network.fetch_vm_info", mock_fetch_vm_info) + @patch("aleph_client.commands.instance.allocate", mock_allocate) + @patch("aleph_client.commands.instance.confidential_init_session", mock_confidential_init_session) + @patch.object(asyncio, "sleep", AsyncMock()) + @patch("aleph_client.commands.instance.confidential_start", mock_confidential_start) + async def coco_create(instance_spec): + print() # For better display when pytest -v -s + all_args = dict( + vm_id=None, + payment_type=None, + payment_chain=None, + crn_hash=None, + crn_url=None, + ssh_pubkey_file=FAKE_PUBKEY_FILE, + name="mock_instance", + vm_secret="fake_secret", + vcpus=None, + memory=None, + timeout_seconds=settings.DEFAULT_VM_TIMEOUT, + gpu=False, + rootfs=None, + rootfs_size=None, + skip_volume=True, + persistent_volume=None, + ephemeral_volume=None, + immutable_volume=None, + policy=0x1, + confidential_firmware=FAKE_STORE_HASH, + firmware_hash=None, + firmware_file="/fake/file", + keep_session=False, + channel=settings.DEFAULT_CHANNEL, + private_key=None, + private_key_file=None, + debug=False, + ) + all_args.update(instance_spec) + await confidential_create(**all_args) + + await coco_create(args) + mock_shutil.which.assert_called_once() + if len(args) > 1: + mock_create.assert_called_once() + else: + mock_auth_client.get_message.assert_called_once() + mock_client.get_message.assert_called_once() + mock_fetch_vm_info.assert_called_once() + mock_allocate.assert_called_once() + mock_confidential_init_session.assert_called_once() + mock_confidential_start.assert_called_once() diff --git a/tests/unit/test_program.py b/tests/unit/test_program.py new file mode 100644 index 00000000..186a78ea --- /dev/null +++ b/tests/unit/test_program.py @@ -0,0 +1,353 @@ +from __future__ import annotations + +import contextlib +import random +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest +from aleph.sdk.conf import settings +from aleph_message.models import Chain + +from aleph_client.commands.program import ( + delete, + list_programs, + logs, + persist, + runtime_checker, + unpersist, + update, + upload, +) + +from .mocks import ( + FAKE_PROGRAM_HASH, + FAKE_PROGRAM_HASH_2, + FAKE_STORE_HASH, + FAKE_VM_HASH, + Dict, + create_mock_load_account, +) + + +def create_mock_program_message(mock_account, program_item_hash=None, persistent=False, allow_amend=True): + if not program_item_hash: + tmp = list(FAKE_PROGRAM_HASH) + random.shuffle(tmp) + program_item_hash = "".join(tmp) + program = Dict( + chain=Chain.ETH, + sender=mock_account.get_address(), + type="program", + channel="ALEPH-CLOUDSOLUTIONS", + confirmed=True, + item_type="inline", + item_hash=program_item_hash, + content=Dict( + item_type="storage", # for fake store message by convenience + type="vm-function", + address=mock_account.get_address(), + time=1734037086.2333803, + metadata=dict(name="mock_program"), + resources=Dict(vcpus=1, memory=1024, seconds=30), + volumes=[ + Dict(name="immutable", mount="/opt/packages", ref=FAKE_STORE_HASH), + Dict(name="ephemeral", mount="/opt/temp", ephemeral=True, size_mib=1024), + Dict(name="persistent", mount="/opt/utils", persistence=Dict(value="host"), size_mib=1024), + ], + code=Dict(encoding="squashfs", entrypoint="main:app", ref=FAKE_STORE_HASH), + runtime=Dict(ref=FAKE_STORE_HASH), + on=Dict(http=True, persistent=persistent), + allow_amend=allow_amend, + ), + ) + return program + + +def create_mock_program_messages(mock_account): + return AsyncMock( + return_value=[ + create_mock_program_message(mock_account), + create_mock_program_message(mock_account, persistent=True), + ] + ) + + +def create_mock_auth_client(mock_account, swap_persistent=False): + mock_response_get_message = create_mock_program_message(mock_account, persistent=swap_persistent) + mock_response_get_message_2 = create_mock_program_message( + mock_account, program_item_hash=FAKE_PROGRAM_HASH_2, persistent=not swap_persistent + ) + mock_auth_client = AsyncMock( + get_messages=AsyncMock(), + get_message=AsyncMock(return_value=mock_response_get_message), + create_store=AsyncMock(return_value=[MagicMock(item_hash=FAKE_STORE_HASH), 200]), + create_program=AsyncMock(return_value=[MagicMock(item_hash=FAKE_PROGRAM_HASH), 200]), + forget=AsyncMock(return_value=(MagicMock(), 200)), + submit=AsyncMock(return_value=[mock_response_get_message_2, 200, MagicMock()]), + ) + mock_auth_client_class = MagicMock() + mock_auth_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_auth_client) + return mock_auth_client_class, mock_auth_client + + +@contextlib.asynccontextmanager +async def vm_client_operate(vm_id, operation, method="GET"): + yield AsyncMock( + url="https://crn.example.com", + status=200, + json=AsyncMock( + return_value=[ + dict( + __REALTIME_TIMESTAMP="2024-02-02 23:34:21", + MESSAGE="hello world", + ) + ] + ), + ) + + +def create_mock_vm_client(): + mock_vm_client = AsyncMock(operate=vm_client_operate) + mock_vm_client_class = MagicMock() + mock_vm_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_vm_client) + return mock_vm_client_class, mock_vm_client + + +@contextlib.asynccontextmanager +async def mock_client_session_get(self, program_url): + yield AsyncMock( + raise_for_status=MagicMock(), + json=AsyncMock( + return_value={ + "Distribution": "Debian GNU/Linux 12 (bookworm)", + "Python": "3.11.2", + "Docker": "Docker version 20.10.24+dfsg1, build 297e128", + "Nodejs": "v18.13.0", + "Rust": "Not installed", + "Go": "Not installed", + } + ), + ) + + +@pytest.mark.asyncio +async def test_upload_program(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + + @patch("aleph_client.commands.program._load_account", mock_load_account) + @patch("aleph_client.utils.os.path.isfile", MagicMock(return_value=True)) + @patch("aleph_client.commands.program.AuthenticatedAlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.program.open", MagicMock()) + async def upload_program(): + print() # For better display when pytest -v -s + returned = await upload( + path=Path("/fake/file.squashfs"), + entrypoint="main:app", + channel=settings.DEFAULT_CHANNEL, + memory=settings.DEFAULT_VM_MEMORY, + vcpus=settings.DEFAULT_VM_VCPUS, + timeout_seconds=settings.DEFAULT_VM_TIMEOUT, + name="mock_program", + runtime=settings.DEFAULT_RUNTIME_ID, + beta=False, + persistent=False, + updatable=True, + skip_volume=True, + skip_env_var=True, + private_key=None, + private_key_file=None, + print_messages=False, + print_code_message=False, + print_program_message=False, + verbose=True, + debug=False, + ) + mock_load_account.assert_called_once() + mock_auth_client.create_store.assert_called_once() + mock_auth_client.create_program.assert_called_once() + assert returned == FAKE_PROGRAM_HASH + + await upload_program() + + +@pytest.mark.asyncio +async def test_update_program(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + + @patch("aleph_client.commands.program._load_account", mock_load_account) + @patch("aleph_client.utils.os.path.isfile", MagicMock(return_value=True)) + @patch("aleph_client.commands.program.AuthenticatedAlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.program.open", MagicMock()) + async def update_program(): + print() # For better display when pytest -v -s + await update( + item_hash=FAKE_PROGRAM_HASH, + path=Path("/fake/file.squashfs"), + private_key=None, + private_key_file=None, + print_message=False, + verbose=True, + debug=False, + ) + mock_load_account.assert_called_once() + assert mock_auth_client.get_message.call_count == 2 + mock_auth_client.create_store.assert_called_once() + + await update_program() + + +@pytest.mark.asyncio +async def test_delete_program(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + + @patch("aleph_client.commands.program._load_account", mock_load_account) + @patch("aleph_client.commands.program.AuthenticatedAlephHttpClient", mock_auth_client_class) + async def delete_program(): + print() # For better display when pytest -v -s + await delete( + item_hash=FAKE_PROGRAM_HASH, + keep_code=False, + private_key=None, + private_key_file=None, + print_message=False, + verbose=True, + debug=False, + ) + mock_load_account.assert_called_once() + assert mock_auth_client.get_message.call_count == 2 + assert mock_auth_client.forget.call_count == 2 + + await delete_program() + + +@pytest.mark.asyncio +async def test_list_programs(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_program_messages = create_mock_program_messages(mock_account) + + @patch("aleph_client.commands.program._load_account", mock_load_account) + @patch("aleph_client.commands.program.AlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.program.filter_only_valid_messages", mock_program_messages) + async def list_program(): + print() # For better display when pytest -v -s + await list_programs( + address=mock_account.get_address(), + private_key=None, + private_key_file=None, + json=False, + debug=False, + ) + mock_program_messages.assert_called_once() + mock_auth_client.get_messages.assert_called_once() + + await list_program() + + +@pytest.mark.asyncio +async def test_persist_program(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + + @patch("aleph_client.commands.program._load_account", mock_load_account) + @patch("aleph_client.commands.program.AuthenticatedAlephHttpClient", mock_auth_client_class) + async def persist_program(): + print() # For better display when pytest -v -s + returned = await persist( + item_hash=FAKE_PROGRAM_HASH, + keep_prev=False, + private_key=None, + private_key_file=None, + print_message=False, + verbose=True, + debug=False, + ) + mock_load_account.assert_called_once() + mock_auth_client.get_message.assert_called_once() + mock_auth_client.submit.assert_called_once() + mock_auth_client.forget.assert_called_once() + assert returned == FAKE_PROGRAM_HASH_2 + + await persist_program() + + +@pytest.mark.asyncio +async def test_unpersist_program(): + mock_load_account = create_mock_load_account() + mock_account = mock_load_account.return_value + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account, swap_persistent=True) + + @patch("aleph_client.commands.program._load_account", mock_load_account) + @patch("aleph_client.commands.program.AuthenticatedAlephHttpClient", mock_auth_client_class) + async def unpersist_program(): + print() # For better display when pytest -v -s + returned = await unpersist( + item_hash=FAKE_PROGRAM_HASH, + keep_prev=False, + private_key=None, + private_key_file=None, + print_message=False, + verbose=True, + debug=False, + ) + mock_load_account.assert_called_once() + mock_auth_client.get_message.assert_called_once() + mock_auth_client.submit.assert_called_once() + mock_auth_client.forget.assert_called_once() + assert returned == FAKE_PROGRAM_HASH_2 + + await unpersist_program() + + +@pytest.mark.asyncio +async def test_logs_program(capsys): + mock_load_account = create_mock_load_account() + mock_vm_client_class, _ = create_mock_vm_client() + + @patch("aleph_client.commands.program._load_account", mock_load_account) + @patch("aleph_client.commands.program.VmClient", mock_vm_client_class) + async def logs_program(): + print() # For better display when pytest -v -s + await logs( + FAKE_VM_HASH, + domain="https://crn.example.com", + chain=Chain.ETH, + debug=False, + ) + + await logs_program() + captured = capsys.readouterr() + assert captured.out == "\nReceived logs\n2024-02-02 23:34:21> hello world\n" + + +@pytest.mark.asyncio +async def test_runtime_checker_program(): + mock_upload = AsyncMock(return_value=FAKE_PROGRAM_HASH) + mock_delete = AsyncMock() + + @patch("aleph_client.commands.program.upload", mock_upload) + @patch.object(aiohttp.ClientSession, "get", mock_client_session_get) + @patch("aleph_client.commands.program.delete", mock_delete) + async def runtime_checker_program(): + print() # For better display when pytest -v -s + await runtime_checker( + item_hash=FAKE_STORE_HASH, + private_key=None, + private_key_file=None, + verbose=True, + debug=False, + ) + mock_upload.assert_called_once() + mock_delete.assert_called_once() + + await runtime_checker_program()