diff --git a/pyproject.toml b/pyproject.toml index 37282041..dbf6e74d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,17 +31,20 @@ dependencies = [ "aiodns==3.2", "aiohttp==3.11.13", "aleph-message>=1.0.5", - "aleph-sdk-python>=2.1", - "base58==2.1.1", # Needed now as default with _load_account changement + #"aleph-sdk-python>=2.1", + "aleph-sdk-python @ git+https://github.com/aleph-im/aleph-sdk-python@andres-feature-implement_ledger_wallet", + "base58==2.1.1", # Needed now as default with _load_account changement "click<8.2", - "py-sr25519-bindings==0.2", # Needed for DOT signatures + "ledgerblue>=0.1.48", + "ledgereth>=0.10", + "py-sr25519-bindings==0.2", # Needed for DOT signatures "pydantic>=2", "pygments==2.19.1", - "pynacl==1.5", # Needed now as default with _load_account changement + "pynacl==1.5", # Needed now as default with _load_account changement "python-magic==0.4.27", "rich==13.9.*", "setuptools>=65.5", - "substrate-interface==1.7.11", # Needed for DOT signatures + "substrate-interface==1.7.11", # Needed for DOT signatures "textual==0.73", "typer==0.15.2", ] diff --git a/src/aleph_client/commands/account.py b/src/aleph_client/commands/account.py index 18157ff4..b9226c68 100644 --- a/src/aleph_client/commands/account.py +++ b/src/aleph_client/commands/account.py @@ -14,6 +14,7 @@ from aleph.sdk.chains.common import generate_key from aleph.sdk.chains.solana import parse_private_key as parse_solana_private_key from aleph.sdk.conf import ( + AccountType, MainConfiguration, load_main_configuration, save_main_configuration, @@ -24,8 +25,11 @@ get_chains_with_super_token, get_compatible_chains, ) +from aleph.sdk.types import AccountFromPrivateKey from aleph.sdk.utils import bytes_from_hex, displayable_amount +from aleph.sdk.wallets.ledger import LedgerETHAccount from aleph_message.models import Chain +from ledgereth.exceptions import LedgerError from rich import box from rich.console import Console from rich.panel import Panel @@ -42,7 +46,7 @@ validated_prompt, yes_no_input, ) -from aleph_client.utils import AsyncTyper, list_unlinked_keys +from aleph_client.utils import AsyncTyper, list_unlinked_keys, load_account logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) @@ -145,26 +149,31 @@ async def create( @app.command(name="address") def display_active_address( - private_key: Annotated[Optional[str], typer.Option(help=help_strings.PRIVATE_KEY)] = settings.PRIVATE_KEY_STRING, - private_key_file: Annotated[ - Optional[Path], typer.Option(help=help_strings.PRIVATE_KEY_FILE) - ] = settings.PRIVATE_KEY_FILE, + private_key: Annotated[Optional[str], typer.Option(help=help_strings.PRIVATE_KEY)] = None, + private_key_file: Annotated[Optional[Path], typer.Option(help=help_strings.PRIVATE_KEY_FILE)] = None, ): """ Display your public address(es). """ + # For regular accounts and Ledger accounts + evm_account = load_account(private_key, private_key_file, chain=Chain.ETH) + evm_address = evm_account.get_address() - if private_key is not None: - private_key_file = None - elif private_key_file and not private_key_file.exists(): - typer.secho("No private key available", fg=RED) - raise typer.Exit(code=1) + # For Ledger accounts, the SOL address might not be available + try: + sol_address = load_account(private_key, private_key_file, chain=Chain.SOL).get_address() + except Exception: + sol_address = "Not available (using Ledger device)" - evm_address = _load_account(private_key, private_key_file, chain=Chain.ETH).get_address() - sol_address = _load_account(private_key, private_key_file, chain=Chain.SOL).get_address() + # Detect if it's a Ledger account + config_file_path = Path(settings.CONFIG_FILE) + config = load_main_configuration(config_file_path) + account_type = config.type if config else None + + account_type_str = " (Ledger)" if account_type == AccountType.HARDWARE else "" console.print( - "✉ [bold italic blue]Addresses for Active Account[/bold italic blue] ✉\n\n" + f"✉ [bold italic blue]Addresses for Active Account{account_type_str}[/bold italic blue] ✉\n\n" f"[italic]EVM[/italic]: [cyan]{evm_address}[/cyan]\n" f"[italic]SOL[/italic]: [magenta]{sol_address}[/magenta]\n" ) @@ -229,16 +238,31 @@ def export_private_key( """ Display your private key. """ + # Check if we're using a Ledger account + config_file_path = Path(settings.CONFIG_FILE) + config = load_main_configuration(config_file_path) + + if config and config.type == AccountType.HARDWARE: + typer.secho("Cannot export private key from a Ledger hardware wallet", fg=RED) + typer.secho("The private key remains securely stored on your Ledger device", fg=RED) + raise typer.Exit(code=1) + # Normal private key handling if private_key: private_key_file = None elif private_key_file and not private_key_file.exists(): typer.secho("No private key available", fg=RED) raise typer.Exit(code=1) - evm_pk = _load_account(private_key, private_key_file, chain=Chain.ETH).export_private_key() - sol_pk = _load_account(private_key, private_key_file, chain=Chain.SOL).export_private_key() + eth_account = _load_account(private_key, private_key_file, chain=Chain.ETH) + sol_account = _load_account(private_key, private_key_file, chain=Chain.SOL) + evm_pk = "Not Available" + if isinstance(eth_account, AccountFromPrivateKey): + evm_pk = eth_account.export_private_key() + sol_pk = "Not Available" + if isinstance(sol_account, AccountFromPrivateKey): + sol_pk = sol_account.export_private_key() console.print( "⚠️ [bold italic red]Private Keys for Active Account[/bold italic red] ⚠️\n\n" f"[italic]EVM[/italic]: [cyan]{evm_pk}[/cyan]\n" @@ -261,7 +285,7 @@ def sign_bytes( setup_logging(debug) - account = _load_account(private_key, private_key_file, chain=chain) + account = load_account(private_key, private_key_file, chain=chain) if not message: message = input_multiline() @@ -296,7 +320,7 @@ async def balance( chain: Annotated[Optional[Chain], typer.Option(help=help_strings.ADDRESS_CHAIN)] = None, ): """Display your ALEPH balance and basic voucher information.""" - account = _load_account(private_key, private_key_file, chain=chain) + account = load_account(private_key, private_key_file, chain=chain) if account and not address: address = account.get_address() @@ -381,9 +405,12 @@ async def list_accounts(): table.add_column("Active", no_wrap=True) active_chain = None - if config: + if config and config.path: active_chain = config.chain table.add_row(config.path.stem, str(config.path), "[bold green]*[/bold green]") + elif config and config.address and config.type == AccountType.HARDWARE: + active_chain = config.chain + table.add_row(f"Ledger ({config.address[:8]}...)", "External (Ledger)", "[bold green]*[/bold green]") else: console.print( "[red]No private key path selected in the config file.[/red]\nTo set it up, use: [bold " @@ -395,13 +422,27 @@ async def list_accounts(): if key_file.stem != "default": table.add_row(key_file.stem, str(key_file), "[bold red]-[/bold red]") + # Try to detect Ledger devices + try: + ledger_accounts = LedgerETHAccount.get_accounts() + if ledger_accounts: + for idx, ledger_acc in enumerate(ledger_accounts): + is_active = config and config.type == AccountType.HARDWARE and config.address == ledger_acc.address + status = "[bold green]*[/bold green]" if is_active else "[bold red]-[/bold red]" + table.add_row(f"Ledger #{idx}", f"{ledger_acc.address}", status) + except Exception: + logger.info("No ledger detected") + hold_chains = [*get_chains_with_holding(), Chain.SOL.value] payg_chains = get_chains_with_super_token() active_address = None - if config and config.path and active_chain: - account = _load_account(private_key_path=config.path, chain=active_chain) - active_address = account.get_address() + if config and active_chain: + if config.path: + account = _load_account(private_key_path=config.path, chain=active_chain) + active_address = account.get_address() + elif config.address and config.type == AccountType.HARDWARE: + active_address = config.address console.print( "🌐 [bold italic blue]Chain Infos[/bold italic blue] 🌐\n" @@ -425,7 +466,7 @@ async def vouchers( chain: Annotated[Optional[Chain], typer.Option(help=help_strings.ADDRESS_CHAIN)] = None, ): """Display detailed information about your vouchers.""" - account = _load_account(private_key, private_key_file, chain=chain) + account = load_account(private_key, private_key_file, chain=chain) if account and not address: address = account.get_address() @@ -476,9 +517,16 @@ async def vouchers( async def configure( private_key_file: Annotated[Optional[Path], typer.Option(help="New path to the private key file")] = None, chain: Annotated[Optional[Chain], typer.Option(help="New active chain")] = None, + address: Annotated[Optional[str], typer.Option(help="New active address")] = None, + account_type: Annotated[Optional[AccountType], typer.Option(help="Account type")] = None, ): """Configure current private key file and active chain (default selection)""" + if settings.CONFIG_HOME: + settings.CONFIG_FILE.parent.mkdir(parents=True, exist_ok=True) # Ensure config file is created + private_keys_dir = Path(settings.CONFIG_HOME, "private-keys") # ensure private-keys folder created + private_keys_dir.mkdir(parents=True, exist_ok=True) + unlinked_keys, config = await list_unlinked_keys() # Fixes private key file path @@ -493,8 +541,12 @@ async def configure( typer.secho(f"Private key file not found: {private_key_file}", fg=typer.colors.RED) raise typer.Exit() - # Configures active private key file - if not private_key_file and config and hasattr(config, "path") and Path(config.path).exists(): + # If private_key_file is specified via command line, prioritize it + if private_key_file: + pass + elif not account_type or ( + account_type == AccountType.IMPORTED and config and hasattr(config, "path") and Path(config.path).exists() + ): if not yes_no_input( f"Active private key file: [bright_cyan]{config.path}[/bright_cyan]\n[yellow]Keep current active private " "key?[/yellow]", @@ -520,12 +572,45 @@ async def configure( else: # No change private_key_file = Path(config.path) - if not private_key_file: - typer.secho("No private key file provided or found.", fg=typer.colors.RED) - raise typer.Exit() + if not private_key_file and account_type == AccountType.HARDWARE: + if yes_no_input( + "[bright_cyan]Loading External keys.[/bright_cyan] [yellow]Do you want to import from Ledger?[/yellow]", + default="y", + ): + try: + + accounts = LedgerETHAccount.get_accounts() + account_addresses = [acc.address for acc in accounts] + + console.print("[bold cyan]Available addresses on Ledger:[/bold cyan]") + for idx, account_address in enumerate(account_addresses, start=1): + console.print(f"[{idx}] {account_address}") + + key_choice = Prompt.ask("Choose a address by index") + if key_choice.isdigit(): + key_index = int(key_choice) - 1 + selected_address = account_addresses[key_index] + + if not selected_address: + typer.secho("No valid address selected.", fg=typer.colors.RED) + raise typer.Exit() + + address = selected_address + account_type = AccountType.HARDWARE + except LedgerError as e: + logger.warning(f"Ledger Error : {e.message}") + raise typer.Exit(code=1) from e + except OSError as err: + logger.warning("Please ensure Udev rules are set to use Ledger") + raise typer.Exit(code=1) from err + else: + typer.secho("No private key file provided or found.", fg=typer.colors.RED) + raise typer.Exit() - # Configure active chain - if not chain and config and hasattr(config, "chain"): + # If chain is specified via command line, prioritize it + if chain: + pass + elif config and hasattr(config, "chain"): if not yes_no_input( f"Active chain: [bright_cyan]{config.chain}[/bright_cyan]\n[yellow]Keep current active chain?[/yellow]", default="y", @@ -544,12 +629,15 @@ async def configure( typer.secho("No chain provided.", fg=typer.colors.RED) raise typer.Exit() + if not account_type: + account_type = AccountType.IMPORTED + try: - config = MainConfiguration(path=private_key_file, chain=chain) + config = MainConfiguration(path=private_key_file, chain=chain, address=address, type=account_type) save_main_configuration(settings.CONFIG_FILE, config) console.print( - f"New Default Configuration: [italic bright_cyan]{config.path}[/italic bright_cyan] with [italic " - f"bright_cyan]{config.chain}[/italic bright_cyan]", + f"New Default Configuration: [italic bright_cyan]{config.path or config.address}" + f"[/italic bright_cyan] with [italic bright_cyan]{config.chain}[/italic bright_cyan]", style=typer.colors.GREEN, ) except ValueError as e: diff --git a/src/aleph_client/commands/aggregate.py b/src/aleph_client/commands/aggregate.py index c2848e33..bdcadc94 100644 --- a/src/aleph_client/commands/aggregate.py +++ b/src/aleph_client/commands/aggregate.py @@ -8,10 +8,8 @@ import typer from aiohttp import ClientResponseError, ClientSession -from aleph.sdk.account import _load_account from aleph.sdk.client import AuthenticatedAlephHttpClient from aleph.sdk.conf import settings -from aleph.sdk.types import AccountFromPrivateKey from aleph.sdk.utils import extended_json_encoder from aleph_message.models import Chain, MessageType from aleph_message.status import MessageStatus @@ -21,7 +19,7 @@ from aleph_client.commands import help_strings from aleph_client.commands.utils import setup_logging -from aleph_client.utils import AsyncTyper, sanitize_url +from aleph_client.utils import AccountTypes, AsyncTyper, load_account, sanitize_url logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) @@ -59,7 +57,7 @@ async def forget( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) address = account.get_address() if address is None else address if key == "security" and not is_same_context(): @@ -132,7 +130,7 @@ async def post( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) address = account.get_address() if address is None else address if key == "security" and not is_same_context(): @@ -194,7 +192,7 @@ async def get( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) address = account.get_address() if address is None else address async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: @@ -230,7 +228,7 @@ async def list_aggregates( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) address = account.get_address() if address is None else address aggr_link = f"{sanitize_url(settings.API_HOST)}/api/v0/aggregates/{address}.json" @@ -304,7 +302,7 @@ async def authorize( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) data = await get( key="security", @@ -378,7 +376,7 @@ async def revoke( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) data = await get( key="security", @@ -433,7 +431,7 @@ async def permissions( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) address = account.get_address() if address is None else address data = await get( diff --git a/src/aleph_client/commands/credit.py b/src/aleph_client/commands/credit.py index 54b8dcde..8bf4f56b 100644 --- a/src/aleph_client/commands/credit.py +++ b/src/aleph_client/commands/credit.py @@ -7,7 +7,6 @@ from aleph.sdk import AlephHttpClient from aleph.sdk.account import _load_account from aleph.sdk.conf import settings -from aleph.sdk.types import AccountFromPrivateKey from aleph.sdk.utils import displayable_amount from rich import box from rich.console import Console @@ -17,7 +16,7 @@ from aleph_client.commands import help_strings from aleph_client.commands.utils import setup_logging -from aleph_client.utils import AsyncTyper +from aleph_client.utils import AccountTypes, AsyncTyper logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) @@ -41,7 +40,7 @@ async def show( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = _load_account(private_key, private_key_file) if account and not address: address = account.get_address() @@ -87,7 +86,7 @@ async def history( ): setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = _load_account(private_key, private_key_file) if account and not address: address = account.get_address() diff --git a/src/aleph_client/commands/domain.py b/src/aleph_client/commands/domain.py index 538bdcbd..9161b190 100644 --- a/src/aleph_client/commands/domain.py +++ b/src/aleph_client/commands/domain.py @@ -6,7 +6,6 @@ from typing import Annotated, Optional, cast import typer -from aleph.sdk.account import _load_account from aleph.sdk.client import AlephHttpClient, AuthenticatedAlephHttpClient from aleph.sdk.conf import settings from aleph.sdk.domain import ( @@ -18,7 +17,6 @@ ) from aleph.sdk.exceptions import DomainConfigurationError from aleph.sdk.query.filters import MessageFilter -from aleph.sdk.types import AccountFromPrivateKey from aleph_message.models import AggregateMessage from aleph_message.models.base import MessageType from rich.console import Console @@ -27,7 +25,7 @@ from aleph_client.commands import help_strings from aleph_client.commands.utils import is_environment_interactive -from aleph_client.utils import AsyncTyper +from aleph_client.utils import AccountTypes, AsyncTyper, load_account logger = logging.getLogger(__name__) @@ -65,7 +63,7 @@ async def check_domain_records(fqdn, target, owner): async def attach_resource( - account: AccountFromPrivateKey, + account, fqdn: Hostname, item_hash: Optional[str] = None, catch_all_path: Optional[str] = None, @@ -137,7 +135,7 @@ async def attach_resource( ) -async def detach_resource(account: AccountFromPrivateKey, fqdn: Hostname, interactive: Optional[bool] = None): +async def detach_resource(account: AccountTypes, fqdn: Hostname, interactive: Optional[bool] = None): domain_info = await get_aggregate_domain_info(account, fqdn) interactive = is_environment_interactive() if interactive is None else interactive @@ -187,7 +185,7 @@ async def add( ] = settings.PRIVATE_KEY_FILE, ): """Add and link a Custom Domain.""" - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) interactive = False if (not ask) else is_environment_interactive() console = Console() @@ -272,7 +270,7 @@ async def attach( ] = settings.PRIVATE_KEY_FILE, ): """Attach resource to a Custom Domain.""" - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) await attach_resource( account, @@ -294,7 +292,7 @@ async def detach( ] = settings.PRIVATE_KEY_FILE, ): """Unlink Custom Domain.""" - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) await detach_resource(account, Hostname(fqdn), interactive=False if (not ask) else None) raise typer.Exit() @@ -309,7 +307,7 @@ async def info( ] = settings.PRIVATE_KEY_FILE, ): """Show Custom Domain Details.""" - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) console = Console() domain_validator = DomainValidator() diff --git a/src/aleph_client/commands/files.py b/src/aleph_client/commands/files.py index bad66bcb..5ef3a949 100644 --- a/src/aleph_client/commands/files.py +++ b/src/aleph_client/commands/files.py @@ -10,9 +10,8 @@ import typer from aiohttp import ClientResponseError from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient -from aleph.sdk.account import _load_account from aleph.sdk.conf import settings -from aleph.sdk.types import AccountFromPrivateKey, StorageEnum, StoredContent +from aleph.sdk.types import StorageEnum, StoredContent from aleph.sdk.utils import safe_getattr from aleph_message.models import ItemHash, StoreMessage from aleph_message.status import MessageStatus @@ -23,7 +22,7 @@ from aleph_client.commands import help_strings from aleph_client.commands.utils import setup_logging -from aleph_client.utils import AsyncTyper +from aleph_client.utils import AccountTypes, AsyncTyper, load_account logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) @@ -44,7 +43,7 @@ async def pin( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: result: StoreMessage @@ -75,7 +74,7 @@ async def upload( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: if not path.is_file(): @@ -181,7 +180,7 @@ async def forget( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) hashes = [ItemHash(item_hash) for item_hash in item_hash.split(",")] @@ -270,7 +269,7 @@ async def list_files( json: Annotated[bool, typer.Option(help="Print as json instead of rich table")] = False, ): """List all files for a given address""" - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) if account and not address: address = account.get_address() diff --git a/src/aleph_client/commands/instance/__init__.py b/src/aleph_client/commands/instance/__init__.py index bed8b2d5..ec0705b4 100644 --- a/src/aleph_client/commands/instance/__init__.py +++ b/src/aleph_client/commands/instance/__init__.py @@ -11,7 +11,6 @@ import aiohttp import typer from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient -from aleph.sdk.account import _load_account from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.client.services.crn import NetworkGPUS from aleph.sdk.client.services.pricing import Price @@ -82,7 +81,7 @@ yes_no_input, ) from aleph_client.models import CRNInfo -from aleph_client.utils import AsyncTyper, sanitize_url +from aleph_client.utils import AccountTypes, AsyncTyper, load_account, sanitize_url logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) @@ -167,7 +166,8 @@ async def create( ssh_pubkey: str = ssh_pubkey_file.read_text(encoding="utf-8").strip() # Populates account / address - account = _load_account(private_key, private_key_file, chain=payment_chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=payment_chain) + address = address or settings.ADDRESS_TO_USE or account.get_address() # Start the fetch in the background (async_lru_cache already returns a future) @@ -830,7 +830,7 @@ async def delete( setup_logging(debug) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: try: existing_message: InstanceMessage = await client.get_message( @@ -942,7 +942,7 @@ async def list_instances( setup_logging(debug) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) address = address or settings.ADDRESS_TO_USE or account.get_address() async with AlephHttpClient(api_server=settings.API_HOST) as client: @@ -979,7 +979,7 @@ async def reboot( or Prompt.ask("URL of the CRN (Compute node) on which the VM is running") ) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) async with VmClient(account, domain) as manager: status, result = await manager.reboot_instance(vm_id=vm_id) @@ -1012,7 +1012,7 @@ async def allocate( or Prompt.ask("URL of the CRN (Compute node) on which the VM will be allocated") ) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) async with VmClient(account, domain) as manager: status, result = await manager.start_instance(vm_id=vm_id) @@ -1040,7 +1040,7 @@ async def logs( domain = (domain and sanitize_url(domain)) or await find_crn_of_vm(vm_id) or Prompt.ask(help_strings.PROMPT_CRN_URL) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) async with VmClient(account, domain) as manager: try: @@ -1071,7 +1071,7 @@ async def stop( domain = (domain and sanitize_url(domain)) or await find_crn_of_vm(vm_id) or Prompt.ask(help_strings.PROMPT_CRN_URL) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) async with VmClient(account, domain) as manager: status, result = await manager.stop_instance(vm_id=vm_id) @@ -1110,7 +1110,7 @@ async def confidential_init_session( or Prompt.ask("URL of the CRN (Compute node) on which the session will be initialized") ) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) sevctl_path = find_sevctl_or_exit() @@ -1187,7 +1187,7 @@ async def confidential_start( session_dir.mkdir(exist_ok=True, parents=True) vm_hash = ItemHash(vm_id) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain=chain) sevctl_path = find_sevctl_or_exit() domain = ( diff --git a/src/aleph_client/commands/instance/port_forwarder.py b/src/aleph_client/commands/instance/port_forwarder.py index 58421402..23c84c9f 100644 --- a/src/aleph_client/commands/instance/port_forwarder.py +++ b/src/aleph_client/commands/instance/port_forwarder.py @@ -7,7 +7,6 @@ import typer from aiohttp import ClientResponseError from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient -from aleph.sdk.account import _load_account from aleph.sdk.conf import settings from aleph.sdk.exceptions import MessageNotProcessed, NotAuthorize from aleph.sdk.types import InstanceManual, PortFlags, Ports @@ -21,7 +20,7 @@ from aleph_client.commands import help_strings from aleph_client.commands.utils import setup_logging -from aleph_client.utils import AsyncTyper +from aleph_client.utils import AccountTypes, AsyncTyper, load_account logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) @@ -42,7 +41,7 @@ async def list_ports( setup_logging(debug) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file) address = address or settings.ADDRESS_TO_USE or account.get_address() async with AlephHttpClient(api_server=settings.API_HOST) as client: @@ -160,7 +159,7 @@ async def create( typer.echo("Error: Port must be between 1 and 65535") raise typer.Exit(code=1) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file) # Create the port flags port_flags = PortFlags(tcp=tcp, udp=udp) @@ -213,7 +212,7 @@ async def update( typer.echo("Error: Port must be between 1 and 65535") raise typer.Exit(code=1) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain) # First check if the port forward exists async with AlephHttpClient(api_server=settings.API_HOST) as client: @@ -293,7 +292,7 @@ async def delete( typer.echo("Error: Port must be between 1 and 65535") raise typer.Exit(code=1) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain) # First check if the port forward exists async with AlephHttpClient(api_server=settings.API_HOST) as client: @@ -376,7 +375,7 @@ async def refresh( setup_logging(debug) - account = _load_account(private_key, private_key_file, chain=chain) + account: AccountTypes = load_account(private_key, private_key_file, chain) try: async with AuthenticatedAlephHttpClient(api_server=settings.API_HOST, account=account) as client: diff --git a/src/aleph_client/commands/message.py b/src/aleph_client/commands/message.py index a7a48d60..94ccd68e 100644 --- a/src/aleph_client/commands/message.py +++ b/src/aleph_client/commands/message.py @@ -11,7 +11,6 @@ import typer from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient -from aleph.sdk.account import _load_account from aleph.sdk.conf import settings from aleph.sdk.exceptions import ( ForgottenMessageError, @@ -20,7 +19,7 @@ ) from aleph.sdk.query.filters import MessageFilter from aleph.sdk.query.responses import MessagesResponse -from aleph.sdk.types import AccountFromPrivateKey, StorageEnum +from aleph.sdk.types import StorageEnum from aleph.sdk.utils import extended_json_encoder from aleph_message.models import AlephMessage, ProgramMessage from aleph_message.models.base import MessageType @@ -35,7 +34,7 @@ setup_logging, str_to_datetime, ) -from aleph_client.utils import AsyncTyper +from aleph_client.utils import AccountTypes, AsyncTyper, load_account app = AsyncTyper(no_args_is_help=True) @@ -138,7 +137,7 @@ async def post( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) storage_engine: StorageEnum content: dict @@ -188,7 +187,7 @@ async def amend( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) async with AlephHttpClient(api_server=settings.API_HOST) as client: existing_message: Optional[AlephMessage] = None @@ -253,7 +252,7 @@ async def forget( hash_list: list[ItemHash] = [ItemHash(h) for h in hashes.split(",")] - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: await client.forget(hashes=hash_list, reason=reason, channel=channel) @@ -296,7 +295,7 @@ def sign( setup_logging(debug) - account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + account: AccountTypes = load_account(private_key, private_key_file) if message is None: message = input_multiline() diff --git a/src/aleph_client/commands/program.py b/src/aleph_client/commands/program.py index 4105656f..becdc246 100644 --- a/src/aleph_client/commands/program.py +++ b/src/aleph_client/commands/program.py @@ -24,7 +24,7 @@ ) from aleph.sdk.query.filters import MessageFilter from aleph.sdk.query.responses import PriceResponse -from aleph.sdk.types import AccountFromPrivateKey, StorageEnum, TokenType +from aleph.sdk.types import StorageEnum, TokenType from aleph.sdk.utils import displayable_amount, make_program_content, safe_getattr from aleph_message.models import ( Chain, @@ -127,7 +127,7 @@ async def upload( typer.echo("No such file or directory") raise typer.Exit(code=4) from error - account: AccountFromPrivateKey = _load_account(private_key, private_key_file, chain=payment_chain) + account = _load_account(private_key, private_key_file, chain=payment_chain) address = address or settings.ADDRESS_TO_USE or account.get_address() # Loads default configuration if no chain is set @@ -339,7 +339,7 @@ async def update( typer.echo("No such file or directory") raise typer.Exit(code=4) from error - account: AccountFromPrivateKey = _load_account(private_key, private_key_file, chain=chain) + account = _load_account(private_key, private_key_file, chain=chain) async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: try: diff --git a/src/aleph_client/utils.py b/src/aleph_client/utils.py index cc3c5aaa..997c0fa2 100644 --- a/src/aleph_client/utils.py +++ b/src/aleph_client/utils.py @@ -18,10 +18,18 @@ import aiohttp import typer from aiohttp import ClientSession -from aleph.sdk.conf import MainConfiguration, load_main_configuration, settings +from aleph.sdk.account import AccountTypes, _load_account +from aleph.sdk.conf import ( + AccountType, + MainConfiguration, + load_main_configuration, + settings, +) from aleph.sdk.types import GenericMessage +from aleph_message.models import Chain from aleph_message.models.base import MessageType from aleph_message.models.execution.base import Encoding +from ledgereth.exceptions import LedgerError logger = logging.getLogger(__name__) @@ -190,3 +198,42 @@ def cached_async_function(*args, **kwargs): return ensure_future(async_function(*args, **kwargs)) return cached_async_function + + +def load_account( + private_key_str: Optional[str], private_key_file: Optional[Path], chain: Optional[Chain] = None +) -> AccountTypes: + """ + Two Case Possible + - Account from private key + - Hardware account (ledger) + + We first try to load configurations, if no configurations we fallback to private_key_str / private_key_file. + """ + + # 1st Check for configurations + config_file_path = Path(settings.CONFIG_FILE) + config = load_main_configuration(config_file_path) + + # If no config we try to load private_key_str / private_key_file + if not config: + logger.warning("No config detected fallback to private key") + if private_key_str is not None: + private_key_file = None + + elif private_key_file and not private_key_file.exists(): + logger.error("No account could be retrieved please use `aleph account create` or `aleph account configure`") + raise typer.Exit(code=1) + + if not chain and config: + chain = config.chain + + if config and config.type and config.type == AccountType.HARDWARE: + try: + return _load_account(None, None, chain=chain) + except LedgerError as err: + raise typer.Exit(code=1) from err + except OSError as err: + raise typer.Exit(code=1) from err + else: + return _load_account(private_key_str, private_key_file, chain=chain) diff --git a/tests/unit/test_account_transact.py b/tests/unit/test_account_transact.py index 81a59b1b..3faba2da 100644 --- a/tests/unit/test_account_transact.py +++ b/tests/unit/test_account_transact.py @@ -26,7 +26,7 @@ def test_account_can_transact_success(mock_account): assert mock_account.can_transact() is True -@patch("aleph_client.commands.account._load_account") +@patch("aleph_client.commands.account.load_account") def test_account_can_transact_error_handling(mock_load_account): """Test that error is handled properly when account.can_transact() fails.""" # Setup mock account that will raise InsufficientFundsError diff --git a/tests/unit/test_aggregate.py b/tests/unit/test_aggregate.py index dc03988f..97c75f06 100644 --- a/tests/unit/test_aggregate.py +++ b/tests/unit/test_aggregate.py @@ -67,7 +67,7 @@ async def test_forget(capsys, args): mock_list_aggregates = AsyncMock(return_value=FAKE_AGGREGATE_DATA) mock_auth_client_class, mock_auth_client = create_mock_auth_client() - @patch("aleph_client.commands.aggregate._load_account", mock_load_account) + @patch("aleph_client.commands.aggregate.load_account", mock_load_account) @patch("aleph_client.commands.aggregate.list_aggregates", mock_list_aggregates) @patch("aleph_client.commands.aggregate.AuthenticatedAlephHttpClient", mock_auth_client_class) async def run_forget(aggr_spec): @@ -101,7 +101,7 @@ async def test_post(capsys, args): mock_load_account = create_mock_load_account() mock_auth_client_class, mock_auth_client = create_mock_auth_client() - @patch("aleph_client.commands.aggregate._load_account", mock_load_account) + @patch("aleph_client.commands.aggregate.load_account", mock_load_account) @patch("aleph_client.commands.aggregate.AuthenticatedAlephHttpClient", mock_auth_client_class) async def run_post(aggr_spec): print() # For better display when pytest -v -s @@ -135,7 +135,7 @@ async def test_get(capsys, args, expected): mock_load_account = create_mock_load_account() mock_auth_client_class, mock_auth_client = create_mock_auth_client(return_fetch=FAKE_AGGREGATE_DATA["AI"]) - @patch("aleph_client.commands.aggregate._load_account", mock_load_account) + @patch("aleph_client.commands.aggregate.load_account", mock_load_account) @patch("aleph_client.commands.aggregate.AuthenticatedAlephHttpClient", mock_auth_client_class) async def run_get(aggr_spec): print() # For better display when pytest -v -s @@ -152,7 +152,7 @@ async def run_get(aggr_spec): async def test_list_aggregates(): mock_load_account = create_mock_load_account() - @patch("aleph_client.commands.aggregate._load_account", mock_load_account) + @patch("aleph_client.commands.aggregate.load_account", mock_load_account) @patch.object(aiohttp.ClientSession, "get", mock_client_session_get) async def run_list_aggregates(): print() # For better display when pytest -v -s @@ -169,7 +169,7 @@ async def test_authorize(capsys): mock_get = AsyncMock(return_value=FAKE_AGGREGATE_DATA["security"]) mock_post = AsyncMock(return_value=True) - @patch("aleph_client.commands.aggregate._load_account", mock_load_account) + @patch("aleph_client.commands.aggregate.load_account", mock_load_account) @patch("aleph_client.commands.aggregate.get", mock_get) @patch("aleph_client.commands.aggregate.post", mock_post) async def run_authorize(): @@ -190,7 +190,7 @@ async def test_revoke(capsys): mock_get = AsyncMock(return_value=FAKE_AGGREGATE_DATA["security"]) mock_post = AsyncMock(return_value=True) - @patch("aleph_client.commands.aggregate._load_account", mock_load_account) + @patch("aleph_client.commands.aggregate.load_account", mock_load_account) @patch("aleph_client.commands.aggregate.get", mock_get) @patch("aleph_client.commands.aggregate.post", mock_post) async def run_revoke(): @@ -210,7 +210,7 @@ async def test_permissions(): mock_load_account = create_mock_load_account() mock_get = AsyncMock(return_value=FAKE_AGGREGATE_DATA["security"]) - @patch("aleph_client.commands.aggregate._load_account", mock_load_account) + @patch("aleph_client.commands.aggregate.load_account", mock_load_account) @patch("aleph_client.commands.aggregate.get", mock_get) async def run_permissions(): print() # For better display when pytest -v -s diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index 6199d343..ee574935 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -6,7 +6,7 @@ import pytest from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.conf import settings +from aleph.sdk.conf import AccountType, MainConfiguration, settings from aleph.sdk.exceptions import ( ForgottenMessageError, MessageNotFoundError, @@ -14,7 +14,7 @@ ) from aleph.sdk.query.responses import MessagesResponse from aleph.sdk.types import StorageEnum, StoredContent -from aleph_message.models import PostMessage, StoreMessage +from aleph_message.models import Chain, PostMessage, StoreMessage from typer.testing import CliRunner from aleph_client.__main__ import app @@ -202,12 +202,35 @@ def test_account_import_sol(env_files): assert new_key != old_key -def test_account_address(env_files): +@patch("aleph.sdk.wallets.ledger.ethereum.LedgerETHAccount.get_accounts") +def test_account_address(mock_get_accounts, env_files): settings.CONFIG_FILE = env_files[1] result = runner.invoke(app, ["account", "address", "--private-key-file", str(env_files[0])]) assert result.exit_code == 0 assert result.stdout.startswith("✉ Addresses for Active Account ✉\n\nEVM: 0x") + # Test with ledger device + mock_ledger_account = MagicMock() + mock_ledger_account.address = "0xdeadbeef1234567890123456789012345678beef" + mock_ledger_account.get_address.return_value = "0xdeadbeef1234567890123456789012345678beef" + mock_get_accounts.return_value = [mock_ledger_account] + + # Create a ledger config + ledger_config = MainConfiguration( + path=None, chain=Chain.ETH, type=AccountType.HARDWARE, address=mock_ledger_account.address + ) + + with patch("aleph_client.commands.account.load_main_configuration", return_value=ledger_config): + with patch( + "aleph_client.commands.account.load_account", + side_effect=lambda _, __, chain: ( + mock_ledger_account if chain == Chain.ETH else Exception("Ledger doesn't support SOL") + ), + ): + result = runner.invoke(app, ["account", "address"]) + assert result.exit_code == 0 + assert result.stdout.startswith("✉ Addresses for Active Account (Ledger) ✉\n\nEVM: 0x") + def test_account_chain(env_files): settings.CONFIG_FILE = env_files[1] @@ -236,6 +259,22 @@ def test_account_export_private_key(env_files): assert result.stdout.startswith("⚠️ Private Keys for Active Account ⚠️\n\nEVM: 0x") +def test_account_export_private_key_ledger(): + """Test that export-private-key fails for Ledger devices.""" + # Create a ledger config + ledger_config = MainConfiguration( + path=None, chain=Chain.ETH, type=AccountType.HARDWARE, address="0xdeadbeef1234567890123456789012345678beef" + ) + + with patch("aleph_client.commands.account.load_main_configuration", return_value=ledger_config): + result = runner.invoke(app, ["account", "export-private-key"]) + + # Command should fail with appropriate message + assert result.exit_code == 1 + assert "Cannot export private key from a Ledger hardware wallet" in result.stdout + assert "The private key remains securely stored on your Ledger device" in result.stdout + + def test_account_list(env_files): settings.CONFIG_FILE = env_files[1] result = runner.invoke(app, ["account", "list"]) @@ -243,6 +282,43 @@ def test_account_list(env_files): assert result.stdout.startswith("🌐 Chain Infos 🌐") +@patch("aleph.sdk.wallets.ledger.ethereum.LedgerETHAccount.get_accounts") +def test_account_list_with_ledger(mock_get_accounts): + """Test that account list shows Ledger devices when available.""" + # Create mock Ledger accounts + mock_account1 = MagicMock() + mock_account1.address = "0xdeadbeef1234567890123456789012345678beef" + mock_account2 = MagicMock() + mock_account2.address = "0xcafebabe5678901234567890123456789012cafe" + mock_get_accounts.return_value = [mock_account1, mock_account2] + + # Test with no configuration first + with patch("aleph_client.commands.account.load_main_configuration", return_value=None): + result = runner.invoke(app, ["account", "list"]) + assert result.exit_code == 0 + + # Check that the ledger accounts are listed + assert "Ledger #0" in result.stdout + assert "Ledger #1" in result.stdout + assert mock_account1.address in result.stdout + assert mock_account2.address in result.stdout + + # Test with a ledger account that's active in configuration + ledger_config = MainConfiguration( + path=None, chain=Chain.ETH, type=AccountType.HARDWARE, address=mock_account1.address + ) + + with patch("aleph_client.commands.account.load_main_configuration", return_value=ledger_config): + result = runner.invoke(app, ["account", "list"]) + assert result.exit_code == 0 + + # Check that the active ledger account is marked + assert "Ledger" in result.stdout + assert mock_account1.address in result.stdout + # Just check for asterisk since rich formatting tags may not be visible in test output + assert "*" in result.stdout + + def test_account_sign_bytes(env_files): settings.CONFIG_FILE = env_files[1] result = runner.invoke(app, ["account", "sign-bytes", "--message", "test", "--chain", "ETH"]) @@ -377,6 +453,36 @@ def test_account_config(env_files): assert result.stdout.startswith("New Default Configuration: ") +@patch("aleph.sdk.wallets.ledger.ethereum.LedgerETHAccount.get_accounts") +def test_account_config_with_ledger(mock_get_accounts): + """Test configuring account with a Ledger device.""" + # Create mock Ledger accounts + mock_account1 = MagicMock() + mock_account1.address = "0xdeadbeef1234567890123456789012345678beef" + mock_account2 = MagicMock() + mock_account2.address = "0xcafebabe5678901234567890123456789012cafe" + mock_get_accounts.return_value = [mock_account1, mock_account2] + + # Create a temporary config file + with runner.isolated_filesystem(): + config_dir = Path("test_config") + config_dir.mkdir() + config_file = config_dir / "config.json" + + with ( + patch("aleph.sdk.conf.settings.CONFIG_FILE", config_file), + patch("aleph.sdk.conf.settings.CONFIG_HOME", str(config_dir)), + patch("aleph_client.commands.account.Prompt.ask", return_value="1"), + patch("aleph_client.commands.account.yes_no_input", return_value=True), + ): + + result = runner.invoke(app, ["account", "config", "--account-type", "hardware", "--chain", "ETH"]) + + assert result.exit_code == 0 + assert "New Default Configuration" in result.stdout + assert mock_account1.address in result.stdout + + def test_message_get(mocker, store_message_fixture): # Use subprocess to avoid border effects between tests caused by the initialisation # of the aiohttp client session out of an async context in the SDK. This avoids diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index a050c1fd..ae335171 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -531,7 +531,7 @@ async def test_create_instance(args, expected, mock_crn_list_obj, mock_pricing_i # Setup all required patches with ( patch("aleph_client.commands.instance.validate_ssh_pubkey_file", mock_validate_ssh_pubkey_file), - patch("aleph_client.commands.instance._load_account", mock_load_account), + patch("aleph_client.commands.instance.load_account", mock_load_account), patch("aleph_client.commands.instance.AlephHttpClient", mock_client_class), patch("aleph_client.commands.pricing.AlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class), @@ -620,7 +620,7 @@ async def test_list_instances(mock_crn_list_obj, mock_pricing_info_response, moc ) # Setup all patches - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.network.fetch_latest_crn_version", mock_fetch_latest_crn_version) @patch("aleph_client.commands.files.AlephHttpClient", mock_client_class) @patch("aleph_client.commands.instance.AlephHttpClient", mock_auth_client_class) @@ -657,7 +657,7 @@ async def test_delete_instance(mock_api_response): # We need to mock that there is no CRN information to skip VM erasure mock_auth_client.instance = MagicMock(get_instances_allocations=AsyncMock(return_value=MagicMock(root={}))) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) @patch("aleph_client.commands.instance.fetch_settings", mock_fetch_settings) @@ -709,7 +709,7 @@ async def test_delete_instance_with_insufficient_funds(): } ) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) @patch("aleph_client.commands.instance.fetch_settings", mock_fetch_settings) @@ -753,7 +753,7 @@ async def test_delete_instance_with_detailed_insufficient_funds_error(capsys, mo } ) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) @patch("aleph_client.commands.instance.fetch_settings", mock_fetch_settings) @@ -794,7 +794,7 @@ async def test_reboot_instance(): # Add the mock to the auth client mock_auth_client.instance = MagicMock(get_instances_allocations=AsyncMock(return_value=mock_allocation)) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) async def reboot_instance(): @@ -826,7 +826,7 @@ async def test_allocate_instance(): # Add the mock to the auth client mock_auth_client.instance = MagicMock(get_instances_allocations=AsyncMock(return_value=mock_allocation)) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) async def allocate_instance(): @@ -858,7 +858,7 @@ async def test_logs_instance(capsys): # Add the mock to the auth client mock_auth_client.instance = MagicMock(get_instances_allocations=AsyncMock(return_value=mock_allocation)) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) async def logs_instance(): @@ -892,7 +892,7 @@ async def test_stop_instance(): # Add the mock to the auth client mock_auth_client.instance = MagicMock(get_instances_allocations=AsyncMock(return_value=mock_allocation)) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) async def stop_instance(): @@ -925,7 +925,7 @@ async def test_confidential_init_session(): # Add the mock to the auth client mock_auth_client.instance = MagicMock(get_instances_allocations=AsyncMock(return_value=mock_allocation)) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.utils.shutil", mock_shutil) @patch("aleph_client.commands.instance.shutil", mock_shutil) @@ -967,7 +967,7 @@ async def test_confidential_start(): # Add the mock to the auth client mock_auth_client.instance = MagicMock(get_instances_allocations=AsyncMock(return_value=mock_allocation)) - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.utils.shutil", mock_shutil) @patch("aleph_client.commands.instance.network.AlephHttpClient", mock_auth_client_class) @patch.object(Path, "exists", MagicMock(return_value=True)) @@ -1090,7 +1090,7 @@ async def test_gpu_create_no_gpus_available(mock_crn_list_obj, mock_pricing_info mock_fetch_latest_crn_version = create_mock_fetch_latest_crn_version() mock_validated_prompt = MagicMock(return_value="1") - @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.load_account", mock_load_account) @patch("aleph_client.commands.instance.validate_ssh_pubkey_file", mock_validate_ssh_pubkey_file) @patch("aleph_client.commands.instance.AlephHttpClient", mock_client_class) @patch("aleph_client.commands.pricing.AlephHttpClient", mock_client_class) diff --git a/tests/unit/test_port_forwarder.py b/tests/unit/test_port_forwarder.py index a7387e64..5c782313 100644 --- a/tests/unit/test_port_forwarder.py +++ b/tests/unit/test_port_forwarder.py @@ -98,7 +98,7 @@ async def test_list_ports(mock_auth_setup): mock_console = MagicMock() with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.Console", return_value=mock_console), ): @@ -118,7 +118,7 @@ async def test_list_ports(mock_auth_setup): ) with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, patch("aleph_client.commands.instance.port_forwarder.typer.Exit", side_effect=SystemExit), @@ -142,7 +142,7 @@ async def test_create_port(mock_auth_setup): mock_client_class = mock_auth_setup["mock_client_class"] with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, ): @@ -177,7 +177,7 @@ async def test_update_port(mock_auth_setup): mock_client.port_forwarder.get_ports.return_value = mock_existing_ports with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, @@ -211,7 +211,7 @@ async def test_delete_port(mock_auth_setup): mock_client.port_forwarder.get_ports.return_value = mock_existing_ports with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, @@ -236,7 +236,7 @@ async def test_delete_port(mock_auth_setup): mock_client.port_forwarder.delete_ports.reset_mock() with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, @@ -268,7 +268,7 @@ async def test_delete_port_last_port(mock_auth_setup): mock_client.port_forwarder.update_ports = None with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, @@ -310,7 +310,7 @@ async def test_refresh_port(mock_auth_setup): mock_client.instance.get_instance_allocation_info.return_value = (None, mock_allocation) with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, ): @@ -340,7 +340,7 @@ async def test_refresh_port_no_allocation(mock_auth_setup): mock_client.instance.get_instance_allocation_info.return_value = (None, None) with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, patch("aleph_client.commands.instance.port_forwarder.typer.Exit", side_effect=SystemExit), @@ -376,7 +376,7 @@ async def test_refresh_port_scheduler_allocation(mock_auth_setup): mock_client.instance.get_instance_allocation_info.return_value = (None, mock_allocation) with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, ): @@ -415,7 +415,7 @@ async def test_non_processed_message_statuses(): mock_http_client.port_forwarder.get_ports = AsyncMock(return_value=mock_existing_ports) with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_http_client_class), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_auth_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, @@ -432,7 +432,7 @@ async def test_non_processed_message_statuses(): mock_echo.reset_mock() with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_http_client_class), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_auth_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo, @@ -450,7 +450,7 @@ async def test_non_processed_message_statuses(): mock_echo.reset_mock() with ( - patch("aleph_client.commands.instance.port_forwarder._load_account", mock_load_account), + patch("aleph_client.commands.instance.port_forwarder.load_account", mock_load_account), patch("aleph_client.commands.instance.port_forwarder.AlephHttpClient", mock_http_client_class), patch("aleph_client.commands.instance.port_forwarder.AuthenticatedAlephHttpClient", mock_auth_client_class), patch("aleph_client.commands.instance.port_forwarder.typer.echo") as mock_echo,