From b464bc5d818ae25c40446b8734d6f2a9dc29976a Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Tue, 2 Jul 2024 21:54:48 +0200 Subject: [PATCH 01/14] Problem: A user cannot initialize an already created confidential VM. Solution: Implement `VmConfidentialClient` class to be able to initialize and interact with confidential VMs. --- pyproject.toml | 1 + .../sdk/client/{vmclient.py => vm_client.py} | 0 .../sdk/client/vm_confidential_client.py | 155 ++++++++++++++++ src/aleph/sdk/utils.py | 31 ++++ .../{test_vmclient.py => test_vm_client.py} | 2 +- tests/unit/test_vm_confidential_client.py | 172 ++++++++++++++++++ 6 files changed, 360 insertions(+), 1 deletion(-) rename src/aleph/sdk/client/{vmclient.py => vm_client.py} (100%) create mode 100644 src/aleph/sdk/client/vm_confidential_client.py rename tests/unit/{test_vmclient.py => test_vm_client.py} (99%) create mode 100644 tests/unit/test_vm_confidential_client.py diff --git a/pyproject.toml b/pyproject.toml index b52efe66..1070a7f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "python-magic", "typer", "typing_extensions", + "aioresponses>=0.7.6" ] [project.optional-dependencies] diff --git a/src/aleph/sdk/client/vmclient.py b/src/aleph/sdk/client/vm_client.py similarity index 100% rename from src/aleph/sdk/client/vmclient.py rename to src/aleph/sdk/client/vm_client.py diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py new file mode 100644 index 00000000..305ff8ef --- /dev/null +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -0,0 +1,155 @@ +import json +import logging +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import aiohttp +from aleph_message.models import ItemHash + +from aleph.sdk.client.vm_client import VmClient +from aleph.sdk.types import Account +from aleph.sdk.utils import run_in_subprocess + +logger = logging.getLogger(__name__) + + +class VmConfidentialClient(VmClient): + sevctl_path: Path + + def __init__( + self, + account: Account, + sevctl_path: Path, + node_url: str = "", + session: Optional[aiohttp.ClientSession] = None, + ): + super().__init__(account, node_url, session) + self.sevctl_path = sevctl_path + + async def get_certificates(self) -> Tuple[Optional[int], str]: + url = f"{self.node_url}/about/certificates" + try: + async with self.session.get(url) as response: + data = await response.read() + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file.write(data) + return response.status, tmp_file.name + + except aiohttp.ClientError as e: + logger.error( + f"HTTP error getting node certificates on {self.node_url}: {str(e)}" + ) + return None, str(e) + + async def create_session( + self, vm_id: ItemHash, certificate_path: Path, policy: int + ): + args = [ + "session", + "--name", + vm_id, + str(certificate_path), + str(policy), + ] + try: + # TODO: Check command result + await self.sevctl_cmd(args) + except Exception as e: + raise ValueError(f"Session creation have failed, reason: {str(e)}") + + async def initialize( + self, vm_id: ItemHash, session: Path, godh: Path + ) -> Tuple[Optional[int], str]: + session_file = session.read_bytes() + godh_file = godh.read_bytes() + params = { + "session": session_file, + "godh": godh_file, + } + return await self.perform_confidential_operation( + vm_id, "confidential/initialize", params=params + ) + + async def measurement(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: + status, text = await self.perform_confidential_operation( + vm_id, "confidential/measurement" + ) + if status: + response = json.loads(text) + return status, response + + return status, text + + async def validate_measurement(self, vm_id: ItemHash) -> bool: + # TODO: Implement measurement validation + return True + + async def build_secret( + self, tek_path: Path, tik_path: Path, measurement: str, secret: str + ) -> Tuple[Path, Path]: + current_path = Path().cwd() + secret_header_path = current_path / "secret_header.bin" + secret_payload_path = current_path / "secret_payload.bin" + args = [ + "secret", + "build", + "--tik", + str(tik_path), + "--tek", + str(tek_path), + "--launch-measure-blob", + measurement, + "--secret", + secret, + str(secret_header_path), + str(secret_payload_path), + ] + try: + # TODO: Check command result + await self.sevctl_cmd(args) + return secret_header_path, secret_payload_path + except Exception as e: + raise ValueError(f"Secret building have failed, reason: {str(e)}") + + async def inject_secret( + self, vm_id: ItemHash, packed_header: str, secret: str + ) -> Tuple[Optional[int], str]: + params = { + "packed_header": packed_header, + "secret": secret, + } + status, text = await self.perform_confidential_operation( + vm_id, "confidential/inject_secret", params=params + ) + + if status: + response = json.loads(text) + return status, response + + return status, text + + async def perform_confidential_operation( + self, vm_id: ItemHash, operation: str, params: Optional[Dict[str, Any]] = None + ) -> Tuple[Optional[int], str]: + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) + + url, header = await self._generate_header(vm_id=vm_id, operation=operation) + + try: + async with self.session.post(url, headers=header, data=params) as response: + response_text = await response.text() + return response.status, response_text + + except aiohttp.ClientError as e: + logger.error(f"HTTP error during operation {operation}: {str(e)}") + return None, str(e) + + async def sevctl_cmd(self, *args) -> bytes: + return await run_in_subprocess( + ["sevctl", *args], + check=True, + ) diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 2d1b30c7..130edc38 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,8 +1,10 @@ +import asyncio import errno import hashlib import json import logging import os +import subprocess from datetime import date, datetime, time from enum import Enum from pathlib import Path @@ -11,6 +13,7 @@ Any, Dict, Iterable, + List, Mapping, Optional, Protocol, @@ -220,3 +223,31 @@ def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str: } ) return signed_operation + + +async def run_in_subprocess( + command: List[str], check: bool = True, stdin_input: Optional[bytes] = None +) -> bytes: + """Run the specified command in a subprocess, returns the stdout of the process.""" + logger.debug(f"command: {' '.join(command)}") + + process = await asyncio.create_subprocess_exec( + *command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate(input=stdin_input) + + if check and process.returncode: + logger.error( + f"Command failed with error code {process.returncode}:\n" + f" stdin = {stdin_input!r}\n" + f" command = {command}\n" + f" stdout = {stderr!r}" + ) + raise subprocess.CalledProcessError( + process.returncode, str(command), stderr.decode() + ) + + return stdout diff --git a/tests/unit/test_vmclient.py b/tests/unit/test_vm_client.py similarity index 99% rename from tests/unit/test_vmclient.py rename to tests/unit/test_vm_client.py index d0198c36..6dcd9fdb 100644 --- a/tests/unit/test_vmclient.py +++ b/tests/unit/test_vm_client.py @@ -9,7 +9,7 @@ from yarl import URL from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.client.vmclient import VmClient +from aleph.sdk.client.vm_client import VmClient from .aleph_vm_authentication import ( SignedOperation, diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py new file mode 100644 index 00000000..a7f79dd6 --- /dev/null +++ b/tests/unit/test_vm_confidential_client.py @@ -0,0 +1,172 @@ +import tempfile +from pathlib import Path +from unittest.mock import patch + +import aiohttp +import pytest +from aioresponses import aioresponses +from aleph_message.models import ItemHash + +from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.client.vm_confidential_client import VmConfidentialClient + + +@pytest.mark.asyncio +async def test_perform_confidential_operation(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/test" + + with aioresponses() as m: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url="http://localhost", + session=aiohttp.ClientSession(), + ) + m.post( + f"http://localhost/control/machine/{vm_id}/{operation}", + status=200, + payload="mock_response_text", + ) + + status, response_text = await vm_client.perform_confidential_operation( + vm_id, operation + ) + assert status == 200 + assert response_text == "mock_response_text" + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_initialize_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/initialize" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_bytes = Path(tmp_file.name).read_bytes() + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + tmp_file_path = Path(tmp_file.name) + status, response_text = await vm_client.initialize( + vm_id, session=tmp_file_path, godh=tmp_file_path + ) + assert status == 200 + assert response_text == "mock_response_text" + m.assert_called_once_with( + url, + method="POST", + data={ + "session": tmp_file_bytes, + "godh": tmp_file_bytes, + }, + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_measurement_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/measurement" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + status, response_text = await vm_client.measurement(vm_id) + assert status == 200 + assert response_text == "mock_response_text" + m.assert_called_once_with( + url, + method="POST", + headers=headers, + ) + await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_confidential_inject_secret_instance(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + operation = "confidential/inject_secret" + node_url = "http://localhost" + url = f"{node_url}/control/machine/{vm_id}/{operation}" + headers = { + "X-SignedPubKey": "test_pubkey_token", + "X-SignedOperation": "test_operation_token", + } + test_secret = "test_secret" + packed_header = "test_packed_header" + + with aioresponses() as m: + with patch( + "aleph.sdk.client.vm_confidential_client.VmConfidentialClient._generate_header", + return_value=(url, headers), + ): + vm_client = VmConfidentialClient( + account=account, + sevctl_path=Path("/"), + node_url=node_url, + session=aiohttp.ClientSession(), + ) + m.post( + url, + status=200, + payload="mock_response_text", + ) + status, response_text = await vm_client.inject_secret( + vm_id, secret=test_secret, packed_header=packed_header + ) + assert status == 200 + assert response_text == "mock_response_text" + m.assert_called_once_with( + url, + method="POST", + data={ + "secret": test_secret, + "packed_header": packed_header, + }, + headers=headers, + ) + await vm_client.session.close() From 319877349f7aeaf46e848aaeb667de19401a9918 Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Tue, 2 Jul 2024 21:57:42 +0200 Subject: [PATCH 02/14] Fix: Solve test responses to pass on the CI. --- tests/unit/test_vm_confidential_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index a7f79dd6..5a0510fe 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -34,7 +34,7 @@ async def test_perform_confidential_operation(): vm_id, operation ) assert status == 200 - assert response_text == "mock_response_text" + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses await vm_client.session.close() @@ -73,7 +73,7 @@ async def test_confidential_initialize_instance(): vm_id, session=tmp_file_path, godh=tmp_file_path ) assert status == 200 - assert response_text == "mock_response_text" + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", @@ -116,7 +116,7 @@ async def test_confidential_measurement_instance(): ) status, response_text = await vm_client.measurement(vm_id) assert status == 200 - assert response_text == "mock_response_text" + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", @@ -159,7 +159,7 @@ async def test_confidential_inject_secret_instance(): vm_id, secret=test_secret, packed_header=packed_header ) assert status == 200 - assert response_text == "mock_response_text" + assert response_text == '"mock_response_text"' # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", From 6b2166aa08febbaa56601c16ffdaef8d39eba510 Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Tue, 2 Jul 2024 21:59:10 +0200 Subject: [PATCH 03/14] Fix: Solve code quality issues. --- tests/unit/test_vm_confidential_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index 5a0510fe..19f92db4 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -73,7 +73,9 @@ async def test_confidential_initialize_instance(): vm_id, session=tmp_file_path, godh=tmp_file_path ) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + assert ( + response_text == '"mock_response_text"' + ) # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", From a3047e64b509755dcf3e169acaa93aa659911c86 Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Tue, 2 Jul 2024 22:05:22 +0200 Subject: [PATCH 04/14] Fix: Solve code quality issues. --- tests/unit/test_vm_confidential_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index 19f92db4..3a4a9337 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -118,7 +118,7 @@ async def test_confidential_measurement_instance(): ) status, response_text = await vm_client.measurement(vm_id) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + assert response_text == 'mock_response_text' # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", @@ -161,7 +161,7 @@ async def test_confidential_inject_secret_instance(): vm_id, secret=test_secret, packed_header=packed_header ) assert status == 200 - assert response_text == '"mock_response_text"' # ' ' cause by aioresponses + assert response_text == 'mock_response_text' # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", From f90729447061772d2ddd837709366c3b4a563448 Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Tue, 2 Jul 2024 22:06:47 +0200 Subject: [PATCH 05/14] Fix: Solve code quality issues. --- tests/unit/test_vm_confidential_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index 3a4a9337..3b1ea619 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -118,7 +118,7 @@ async def test_confidential_measurement_instance(): ) status, response_text = await vm_client.measurement(vm_id) assert status == 200 - assert response_text == 'mock_response_text' # ' ' cause by aioresponses + assert response_text == "mock_response_text" # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", @@ -161,7 +161,7 @@ async def test_confidential_inject_secret_instance(): vm_id, secret=test_secret, packed_header=packed_header ) assert status == 200 - assert response_text == 'mock_response_text' # ' ' cause by aioresponses + assert response_text == "mock_response_text" # ' ' cause by aioresponses m.assert_called_once_with( url, method="POST", From 0c815c3b39fcd3592c3908e6f5e911688b4f2c1b Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Tue, 2 Jul 2024 22:07:58 +0200 Subject: [PATCH 06/14] Fix: Remove useless comments. --- tests/unit/test_vm_confidential_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index 3b1ea619..e51c9af6 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -118,7 +118,7 @@ async def test_confidential_measurement_instance(): ) status, response_text = await vm_client.measurement(vm_id) assert status == 200 - assert response_text == "mock_response_text" # ' ' cause by aioresponses + assert response_text == "mock_response_text" m.assert_called_once_with( url, method="POST", @@ -161,7 +161,7 @@ async def test_confidential_inject_secret_instance(): vm_id, secret=test_secret, packed_header=packed_header ) assert status == 200 - assert response_text == "mock_response_text" # ' ' cause by aioresponses + assert response_text == "mock_response_text" m.assert_called_once_with( url, method="POST", From d5b6a42710a24db02146d61623b2739f2d5a576a Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Wed, 3 Jul 2024 16:14:10 +0200 Subject: [PATCH 07/14] Fix: Added 2 new missing tests to check every feature. --- .../sdk/client/vm_confidential_client.py | 12 +-- tests/unit/test_vm_confidential_client.py | 80 +++++++++++++++++++ 2 files changed, 87 insertions(+), 5 deletions(-) diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py index 305ff8ef..7a5d76f1 100644 --- a/src/aleph/sdk/client/vm_confidential_client.py +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -44,17 +44,19 @@ async def get_certificates(self) -> Tuple[Optional[int], str]: async def create_session( self, vm_id: ItemHash, certificate_path: Path, policy: int - ): + ) -> Path: + current_path = Path().cwd() args = [ "session", "--name", - vm_id, + str(vm_id), str(certificate_path), str(policy), ] try: # TODO: Check command result - await self.sevctl_cmd(args) + await self.sevctl_cmd(*args) + return current_path except Exception as e: raise ValueError(f"Session creation have failed, reason: {str(e)}") @@ -107,7 +109,7 @@ async def build_secret( ] try: # TODO: Check command result - await self.sevctl_cmd(args) + await self.sevctl_cmd(*args) return secret_header_path, secret_payload_path except Exception as e: raise ValueError(f"Secret building have failed, reason: {str(e)}") @@ -150,6 +152,6 @@ async def perform_confidential_operation( async def sevctl_cmd(self, *args) -> bytes: return await run_in_subprocess( - ["sevctl", *args], + [str(self.sevctl_path), *args], check=True, ) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index e51c9af6..6dac8836 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -1,5 +1,6 @@ import tempfile from pathlib import Path +from unittest import mock from unittest.mock import patch import aiohttp @@ -172,3 +173,82 @@ async def test_confidential_inject_secret_instance(): headers=headers, ) await vm_client.session.close() + + +@pytest.mark.asyncio +async def test_create_session_command(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") + node_url = "http://localhost" + sevctl_path = Path("/usr/bin/sevctl") + certificates_path = Path("/") + policy = 1 + + with mock.patch( + "aleph.sdk.client.vm_confidential_client.run_in_subprocess", + return_value=True, + ) as export_mock: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=sevctl_path, + node_url=node_url, + session=aiohttp.ClientSession(), + ) + _ = await vm_client.create_session(vm_id, certificates_path, policy) + export_mock.assert_called_once_with( + [ + str(sevctl_path), + "session", + "--name", + str(vm_id), + str(certificates_path), + str(policy), + ], + check=True, + ) + + +@pytest.mark.asyncio +async def test_build_secret_command(): + account = ETHAccount(private_key=b"0x" + b"1" * 30) + node_url = "http://localhost" + sevctl_path = Path("/usr/bin/sevctl") + current_path = Path().cwd() + measurement = "test_measurement" + secret = "test_secret" + expected_secret_header_path = current_path / "secret_header.bin" + expected_secret_payload_path = current_path / "secret_payload.bin" + + with mock.patch( + "aleph.sdk.client.vm_confidential_client.run_in_subprocess", + return_value=True, + ) as export_mock: + vm_client = VmConfidentialClient( + account=account, + sevctl_path=sevctl_path, + node_url=node_url, + session=aiohttp.ClientSession(), + ) + secret_header_path, secret_payload_path = await vm_client.build_secret( + current_path, current_path, measurement, secret + ) + assert expected_secret_header_path == secret_header_path + assert expected_secret_payload_path == secret_payload_path + export_mock.assert_called_once_with( + [ + str(sevctl_path), + "secret", + "build", + "--tik", + str(current_path), + "--tek", + str(current_path), + "--launch-measure-blob", + measurement, + "--secret", + secret, + str(expected_secret_header_path), + str(expected_secret_payload_path), + ], + check=True, + ) From c15f7fef2e1d4d5b9a96cc472fdb8101bde274e3 Mon Sep 17 00:00:00 2001 From: "Andres D. Molins" Date: Wed, 3 Jul 2024 19:28:44 +0200 Subject: [PATCH 08/14] Fix: Implemented measure validation and refactored secret building. --- .../sdk/client/vm_confidential_client.py | 166 ++++++++++++------ src/aleph/sdk/types.py | 19 ++ src/aleph/sdk/utils.py | 126 ++++++++++++- tests/unit/test_vm_confidential_client.py | 87 +++------ 4 files changed, 280 insertions(+), 118 deletions(-) diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py index 7a5d76f1..7373f957 100644 --- a/src/aleph/sdk/client/vm_confidential_client.py +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -1,5 +1,7 @@ +import base64 import json import logging +import os import tempfile from pathlib import Path from typing import Any, Dict, Optional, Tuple @@ -8,8 +10,15 @@ from aleph_message.models import ItemHash from aleph.sdk.client.vm_client import VmClient -from aleph.sdk.types import Account -from aleph.sdk.utils import run_in_subprocess +from aleph.sdk.types import Account, SEVMeasurement +from aleph.sdk.utils import ( + compute_confidential_measure, + encrypt_secret_table, + get_vm_measure, + make_packet_header, + make_secret_table, + run_in_subprocess, +) logger = logging.getLogger(__name__) @@ -28,6 +37,10 @@ def __init__( self.sevctl_path = sevctl_path async def get_certificates(self) -> Tuple[Optional[int], str]: + """ + Get platform confidential certificate + """ + url = f"{self.node_url}/about/certificates" try: async with self.session.get(url) as response: @@ -45,6 +58,10 @@ async def get_certificates(self) -> Tuple[Optional[int], str]: async def create_session( self, vm_id: ItemHash, certificate_path: Path, policy: int ) -> Path: + """ + Create new confidential session + """ + current_path = Path().cwd() args = [ "session", @@ -60,9 +77,11 @@ async def create_session( except Exception as e: raise ValueError(f"Session creation have failed, reason: {str(e)}") - async def initialize( - self, vm_id: ItemHash, session: Path, godh: Path - ) -> Tuple[Optional[int], str]: + async def initialize(self, vm_id: ItemHash, session: Path, godh: Path) -> str: + """ + Initialize Confidential VM negociation passing the needed session files + """ + session_file = session.read_bytes() godh_file = godh.read_bytes() params = { @@ -73,67 +92,103 @@ async def initialize( vm_id, "confidential/initialize", params=params ) - async def measurement(self, vm_id: ItemHash) -> Tuple[Optional[int], str]: - status, text = await self.perform_confidential_operation( - vm_id, "confidential/measurement" - ) - if status: - response = json.loads(text) - return status, response + async def measurement(self, vm_id: ItemHash) -> SEVMeasurement: + """ + Fetch VM confidential measurement + """ - return status, text + if not self.pubkey_signature_header: + self.pubkey_signature_header = ( + await self._generate_pubkey_signature_header() + ) - async def validate_measurement(self, vm_id: ItemHash) -> bool: - # TODO: Implement measurement validation - return True + operation = "confidential/measurement" + url, header = await self._generate_header(vm_id=vm_id, operation=operation) + try: + async with self.session.get(url, headers=header) as response: + response = await response.json() + print(response) + sev_mesurement = SEVMeasurement.parse_obj(response) + return sev_mesurement + + except aiohttp.ClientError as e: + raise ValueError( + f"HTTP error getting node certificates on {self.node_url}: {str(e)}" + ) + + async def validate_measure( + self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str + ) -> bool: + """ + Validate VM confidential measurement + """ + + tik = tik_path.read_bytes() + vm_measure, nonce = get_vm_measure(sev_data) + + expected_measure = compute_confidential_measure( + sev_info=sev_data.sev_info, + tik=tik, + expected_hash=firmware_hash, + nonce=nonce, + ).digest() + return expected_measure == vm_measure async def build_secret( - self, tek_path: Path, tik_path: Path, measurement: str, secret: str - ) -> Tuple[Path, Path]: - current_path = Path().cwd() - secret_header_path = current_path / "secret_header.bin" - secret_payload_path = current_path / "secret_payload.bin" - args = [ - "secret", - "build", - "--tik", - str(tik_path), - "--tek", - str(tek_path), - "--launch-measure-blob", - measurement, - "--secret", - secret, - str(secret_header_path), - str(secret_payload_path), - ] - try: - # TODO: Check command result - await self.sevctl_cmd(*args) - return secret_header_path, secret_payload_path - except Exception as e: - raise ValueError(f"Secret building have failed, reason: {str(e)}") + self, tek_path: Path, tik_path: Path, sev_data: SEVMeasurement, secret: str + ) -> Tuple[str, str]: + """ + Build disk secret to be injected on the confidential VM + """ + + tek = tek_path.read_bytes() + tik = tik_path.read_bytes() + + vm_measure, _ = get_vm_measure(sev_data) + + iv = os.urandom(16) + secret_table = make_secret_table(secret) + encrypted_secret_table = encrypt_secret_table( + secret_table=secret_table, tek=tek, iv=iv + ) + + packet_header = make_packet_header( + vm_measure=vm_measure, + encrypted_secret_table=encrypted_secret_table, + secret_table_size=len(secret_table), + tik=tik, + iv=iv, + ) + + encoded_packet_header = base64.b64encode(packet_header).decode() + encoded_secret = base64.b64encode(encrypted_secret_table).decode() + + return encoded_packet_header, encoded_secret async def inject_secret( - self, vm_id: ItemHash, packed_header: str, secret: str - ) -> Tuple[Optional[int], str]: + self, vm_id: ItemHash, packet_header: str, secret: str + ) -> Dict: + """ + Send the secret by the encrypted channel to boot up the VM + """ + params = { - "packed_header": packed_header, + "packet_header": packet_header, "secret": secret, } - status, text = await self.perform_confidential_operation( + text = await self.perform_confidential_operation( vm_id, "confidential/inject_secret", params=params ) - if status: - response = json.loads(text) - return status, response - - return status, text + return json.loads(text) async def perform_confidential_operation( self, vm_id: ItemHash, operation: str, params: Optional[Dict[str, Any]] = None - ) -> Tuple[Optional[int], str]: + ) -> str: + """ + Send confidential operations to the CRN passing the auth headers on each request + """ + if not self.pubkey_signature_header: self.pubkey_signature_header = ( await self._generate_pubkey_signature_header() @@ -144,13 +199,16 @@ async def perform_confidential_operation( try: async with self.session.post(url, headers=header, data=params) as response: response_text = await response.text() - return response.status, response_text + return response_text except aiohttp.ClientError as e: - logger.error(f"HTTP error during operation {operation}: {str(e)}") - return None, str(e) + raise ValueError(f"HTTP error during operation {operation}: {str(e)}") async def sevctl_cmd(self, *args) -> bytes: + """ + Execute `sevctl` command with given arguments + """ + return await run_in_subprocess( [str(self.sevctl_path), *args], check=True, diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index 71bc2b53..e5d8afcf 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -2,6 +2,8 @@ from enum import Enum from typing import Dict, Protocol, TypeVar +from pydantic import BaseModel + __all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage") from aleph_message.models import AlephMessage @@ -39,3 +41,20 @@ async def sign_raw(self, buffer: bytes) -> bytes: ... GenericMessage = TypeVar("GenericMessage", bound=AlephMessage) + + +class SEVInfo(BaseModel): + """ + An AMD SEV platform information. + """ + + enabled: bool + + +class SEVMeasurement(BaseModel): + """ + A SEV measurement data get from Qemu measurement. + """ + + sev_info: SEVInfo + launch_measure: str diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index 130edc38..db113d3a 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -1,6 +1,8 @@ import asyncio +import base64 import errno import hashlib +import hmac import json import logging import os @@ -23,16 +25,19 @@ Union, get_args, ) +from uuid import UUID from zipfile import BadZipFile, ZipFile from aleph_message.models import ItemHash, MessageType from aleph_message.models.execution.program import Encoding from aleph_message.models.execution.volume import MachineVolume +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from jwcrypto.jwa import JWA from pydantic.json import pydantic_encoder from aleph.sdk.conf import settings -from aleph.sdk.types import GenericMessage +from aleph.sdk.types import GenericMessage, SEVInfo, SEVMeasurement logger = logging.getLogger(__name__) @@ -251,3 +256,122 @@ async def run_in_subprocess( ) return stdout + + +def get_vm_measure(sev_data: SEVMeasurement) -> Tuple[bytes, bytes]: + launch_measure = base64.b64decode(sev_data.launch_measure) + vm_measure = launch_measure[0:32] + nonce = launch_measure[32:48] + return vm_measure, nonce + + +def compute_confidential_measure( + sev_info: SEVInfo, tik: bytes, expected_hash: str, nonce: bytes +) -> hmac.HMAC: + """ + Computes the SEV measurement using the CRN SEV data and local variables like the OVMF firmware hash, + and the session key generated. + """ + + h = hmac.new(tik, digestmod="sha256") + + ## + # calculated per section 6.5.2 + ## + h.update(bytes([0x04])) + h.update(sev_info.api_major.to_bytes(1, byteorder="little")) + h.update(sev_info.api_minor.to_bytes(1, byteorder="little")) + h.update(sev_info.build_id.to_bytes(1, byteorder="little")) + h.update(sev_info.policy.to_bytes(4, byteorder="little")) + + expected_hash_bytes = bytearray.fromhex(expected_hash) + h.update(expected_hash_bytes) + + h.update(nonce) + + return h + + +def make_secret_table(secret: str) -> bytearray: + """ + Makes the disk secret table to be sent to the Confidential CRN + """ + + ## + # Construct the secret table: two guids + 4 byte lengths plus string + # and zero terminator + # + # Secret layout is guid, len (4 bytes), data + # with len being the length from start of guid to end of data + # + # The table header covers the entire table then each entry covers + # only its local data + # + # our current table has the header guid with total table length + # followed by the secret guid with the zero terminated secret + ## + + # total length of table: header plus one entry with trailing \0 + length = 16 + 4 + 16 + 4 + len(secret) + 1 + # SEV-ES requires rounding to 16 + length = (length + 15) & ~15 + secret_table = bytearray(length) + + secret_table[0:16] = UUID("{1e74f542-71dd-4d66-963e-ef4287ff173b}").bytes_le + secret_table[16:20] = len(secret_table).to_bytes(4, byteorder="little") + secret_table[20:36] = UUID("{736869e5-84f0-4973-92ec-06879ce3da0b}").bytes_le + secret_table[36:40] = (16 + 4 + len(secret) + 1).to_bytes(4, byteorder="little") + secret_table[40 : 40 + len(secret)] = secret.encode() + + return secret_table + + +def encrypt_secret_table(secret_table: bytes, tek: bytes, iv: bytes) -> bytes: + """Encrypt the secret table with the TEK in CTR mode using a random IV""" + + # Initialize the cipher with AES algorithm and CTR mode + cipher = Cipher(algorithms.AES(tek), modes.CTR(iv), backend=default_backend()) + encryptor = cipher.encryptor() + + # Encrypt the secret table + encrypted_secret = encryptor.update(secret_table) + encryptor.finalize() + + return encrypted_secret + + +def make_packet_header( + vm_measure: bytes, + encrypted_secret_table: bytes, + secret_table_size: int, + tik: bytes, + iv: bytes, +) -> bytearray: + """ + Creates a packet header using the encrypted disk secret table to be sent to the Confidential CRN + """ + + ## + # ultimately needs to be an argument, but there's only + # compressed and no real use case + ## + flags = 0 + + ## + # Table 55. LAUNCH_SECRET Packet Header Buffer + ## + header = bytearray(52) + header[0:4] = flags.to_bytes(4, byteorder="little") + header[4:20] = iv + + h = hmac.new(tik, digestmod="sha256") + h.update(bytes([0x01])) + # FLAGS || IV + h.update(header[0:20]) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(secret_table_size.to_bytes(4, byteorder="little")) + h.update(encrypted_secret_table) + h.update(vm_measure) + + header[20:52] = h.digest() + + return header diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index 6dac8836..cc25cf9f 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -31,10 +31,7 @@ async def test_perform_confidential_operation(): payload="mock_response_text", ) - status, response_text = await vm_client.perform_confidential_operation( - vm_id, operation - ) - assert status == 200 + response_text = await vm_client.perform_confidential_operation(vm_id, operation) assert response_text == '"mock_response_text"' # ' ' cause by aioresponses await vm_client.session.close() @@ -70,10 +67,9 @@ async def test_confidential_initialize_instance(): payload="mock_response_text", ) tmp_file_path = Path(tmp_file.name) - status, response_text = await vm_client.initialize( + response_text = await vm_client.initialize( vm_id, session=tmp_file_path, godh=tmp_file_path ) - assert status == 200 assert ( response_text == '"mock_response_text"' ) # ' ' cause by aioresponses @@ -112,17 +108,29 @@ async def test_confidential_measurement_instance(): node_url=node_url, session=aiohttp.ClientSession(), ) - m.post( + m.get( url, status=200, - payload="mock_response_text", + payload=dict( + { + "sev_info": { + "enabled": True, + "api_major": 0, + "api_minor": 0, + "build_id": 0, + "policy": 0, + "state": "", + "handle": 0, + }, + "launch_measure": "test_measure", + } + ), ) - status, response_text = await vm_client.measurement(vm_id) - assert status == 200 - assert response_text == "mock_response_text" + measurement = await vm_client.measurement(vm_id) + assert measurement.launch_measure == "test_measure" m.assert_called_once_with( url, - method="POST", + method="GET", headers=headers, ) await vm_client.session.close() @@ -140,7 +148,7 @@ async def test_confidential_inject_secret_instance(): "X-SignedOperation": "test_operation_token", } test_secret = "test_secret" - packed_header = "test_packed_header" + packet_header = "test_packet_header" with aioresponses() as m: with patch( @@ -158,17 +166,16 @@ async def test_confidential_inject_secret_instance(): status=200, payload="mock_response_text", ) - status, response_text = await vm_client.inject_secret( - vm_id, secret=test_secret, packed_header=packed_header + response_text = await vm_client.inject_secret( + vm_id, secret=test_secret, packet_header=packet_header ) - assert status == 200 assert response_text == "mock_response_text" m.assert_called_once_with( url, method="POST", data={ "secret": test_secret, - "packed_header": packed_header, + "packet_header": packet_header, }, headers=headers, ) @@ -206,49 +213,3 @@ async def test_create_session_command(): ], check=True, ) - - -@pytest.mark.asyncio -async def test_build_secret_command(): - account = ETHAccount(private_key=b"0x" + b"1" * 30) - node_url = "http://localhost" - sevctl_path = Path("/usr/bin/sevctl") - current_path = Path().cwd() - measurement = "test_measurement" - secret = "test_secret" - expected_secret_header_path = current_path / "secret_header.bin" - expected_secret_payload_path = current_path / "secret_payload.bin" - - with mock.patch( - "aleph.sdk.client.vm_confidential_client.run_in_subprocess", - return_value=True, - ) as export_mock: - vm_client = VmConfidentialClient( - account=account, - sevctl_path=sevctl_path, - node_url=node_url, - session=aiohttp.ClientSession(), - ) - secret_header_path, secret_payload_path = await vm_client.build_secret( - current_path, current_path, measurement, secret - ) - assert expected_secret_header_path == secret_header_path - assert expected_secret_payload_path == secret_payload_path - export_mock.assert_called_once_with( - [ - str(sevctl_path), - "secret", - "build", - "--tik", - str(current_path), - "--tek", - str(current_path), - "--launch-measure-blob", - measurement, - "--secret", - secret, - str(expected_secret_header_path), - str(expected_secret_payload_path), - ], - check=True, - ) From d42a4ee70de3cd60df28b4402bc3e3c195795f97 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Thu, 4 Jul 2024 15:09:40 +0200 Subject: [PATCH 09/14] Problem: Auth was not working Corrections: * Measurement type returned was missing field needed for validation of measurements * Port number was not handled correctly in authentifaction * Adapt to new auth protocol where domain is moved to the operation field (While keeping compat with the old format) * Get measurement was not working since signed with the wrong method * inject_secret was not sending a json * Websocked auth was sending a twice serialized json --- src/aleph/sdk/client/vm_client.py | 28 +++++++++------ .../sdk/client/vm_confidential_client.py | 35 ++++++++++--------- src/aleph/sdk/types.py | 6 ++++ src/aleph/sdk/utils.py | 7 ++-- 4 files changed, 47 insertions(+), 29 deletions(-) diff --git a/src/aleph/sdk/client/vm_client.py b/src/aleph/sdk/client/vm_client.py index 212b7eb5..35015080 100644 --- a/src/aleph/sdk/client/vm_client.py +++ b/src/aleph/sdk/client/vm_client.py @@ -44,7 +44,7 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]: return { "pubkey": json.loads(self.ephemeral_key.export_public()), "alg": "ECDSA", - "domain": urlparse(self.node_url).netloc, + "domain": urlparse(self.node_url).hostname, "address": self.account.get_address(), "expires": ( datetime.datetime.utcnow() + datetime.timedelta(days=1) @@ -65,14 +65,16 @@ async def _generate_pubkey_signature_header(self) -> str: "sender": self.account.get_address(), "payload": pubkey_payload, "signature": pubkey_signature, - "content": {"domain": urlparse(self.node_url).netloc}, + "content": {"domain": urlparse(self.node_url).hostname}, } ) async def _generate_header( - self, vm_id: ItemHash, operation: str + self, vm_id: ItemHash, operation: str, method: str ) -> Tuple[str, Dict[str, str]]: - payload = create_vm_control_payload(vm_id, operation) + payload = create_vm_control_payload( + vm_id, operation, domain=urlparse(self.node_url).hostname, method=method + ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) if not self.pubkey_signature_header: @@ -89,17 +91,21 @@ async def _generate_header( return f"{self.node_url}{path}", headers async def perform_operation( - self, vm_id: ItemHash, operation: str + self, vm_id: ItemHash, operation: str, method: str = "POST" ) -> Tuple[Optional[int], str]: if not self.pubkey_signature_header: self.pubkey_signature_header = ( await self._generate_pubkey_signature_header() ) - url, header = await self._generate_header(vm_id=vm_id, operation=operation) + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method=method + ) try: - async with self.session.post(url, headers=header) as response: + async with self.session.request( + method=method, url=url, headers=header + ) as response: response_text = await response.text() return response.status, response_text @@ -113,7 +119,9 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: await self._generate_pubkey_signature_header() ) - payload = create_vm_control_payload(vm_id, "logs") + payload = create_vm_control_payload( + vm_id, "stream_logs", method="get", domain=urlparse(self.node_url).hostname + ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) path = payload["path"] ws_url = f"{self.node_url}{path}" @@ -121,8 +129,8 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: async with self.session.ws_connect(ws_url) as ws: auth_message = { "auth": { - "X-SignedPubKey": self.pubkey_signature_header, - "X-SignedOperation": signed_operation, + "X-SignedPubKey": json.loads(self.pubkey_signature_header), + "X-SignedOperation": json.loads(signed_operation), } } await ws.send_json(auth_message) diff --git a/src/aleph/sdk/client/vm_confidential_client.py b/src/aleph/sdk/client/vm_confidential_client.py index 7373f957..a100de8c 100644 --- a/src/aleph/sdk/client/vm_confidential_client.py +++ b/src/aleph/sdk/client/vm_confidential_client.py @@ -102,19 +102,11 @@ async def measurement(self, vm_id: ItemHash) -> SEVMeasurement: await self._generate_pubkey_signature_header() ) - operation = "confidential/measurement" - url, header = await self._generate_header(vm_id=vm_id, operation=operation) - try: - async with self.session.get(url, headers=header) as response: - response = await response.json() - print(response) - sev_mesurement = SEVMeasurement.parse_obj(response) - return sev_mesurement - - except aiohttp.ClientError as e: - raise ValueError( - f"HTTP error getting node certificates on {self.node_url}: {str(e)}" - ) + status, text = await self.perform_operation( + vm_id, "confidential/measurement", method="GET" + ) + sev_mesurement = SEVMeasurement.parse_raw(text) + return sev_mesurement async def validate_measure( self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str @@ -177,13 +169,17 @@ async def inject_secret( "secret": secret, } text = await self.perform_confidential_operation( - vm_id, "confidential/inject_secret", params=params + vm_id, "confidential/inject_secret", json=params ) return json.loads(text) async def perform_confidential_operation( - self, vm_id: ItemHash, operation: str, params: Optional[Dict[str, Any]] = None + self, + vm_id: ItemHash, + operation: str, + params: Optional[Dict[str, Any]] = None, + json=None, ) -> str: """ Send confidential operations to the CRN passing the auth headers on each request @@ -194,10 +190,15 @@ async def perform_confidential_operation( await self._generate_pubkey_signature_header() ) - url, header = await self._generate_header(vm_id=vm_id, operation=operation) + url, header = await self._generate_header( + vm_id=vm_id, operation=operation, method="post" + ) try: - async with self.session.post(url, headers=header, data=params) as response: + async with self.session.post( + url, headers=header, data=params, json=json + ) as response: + response.raise_for_status() response_text = await response.text() return response_text diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index e5d8afcf..cf9e6fa8 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -49,6 +49,12 @@ class SEVInfo(BaseModel): """ enabled: bool + api_major: int + api_minor: int + build_id: int + policy: int + state: str + handle: int class SEVMeasurement(BaseModel): diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index db113d3a..5c641d5c 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -208,12 +208,15 @@ def bytes_from_hex(hex_string: str) -> bytes: return hex_string -def create_vm_control_payload(vm_id: ItemHash, operation: str) -> Dict[str, str]: +def create_vm_control_payload( + vm_id: ItemHash, operation: str, domain: str, method: str +) -> Dict[str, str]: path = f"/control/machine/{vm_id}/{operation}" payload = { "time": datetime.utcnow().isoformat() + "Z", - "method": "POST", + "method": method.upper(), "path": path, + "domain": domain, } return payload From e852717077f5befe4f9cadaa01e530b6db673f06 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Thu, 4 Jul 2024 17:29:17 +0200 Subject: [PATCH 10/14] mypy --- src/aleph/sdk/client/vm_client.py | 15 +++++++++++---- tests/unit/test_vm_client.py | 2 +- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/aleph/sdk/client/vm_client.py b/src/aleph/sdk/client/vm_client.py index 35015080..4092851d 100644 --- a/src/aleph/sdk/client/vm_client.py +++ b/src/aleph/sdk/client/vm_client.py @@ -44,7 +44,7 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]: return { "pubkey": json.loads(self.ephemeral_key.export_public()), "alg": "ECDSA", - "domain": urlparse(self.node_url).hostname, + "domain": self.node_domain, "address": self.account.get_address(), "expires": ( datetime.datetime.utcnow() + datetime.timedelta(days=1) @@ -65,7 +65,7 @@ async def _generate_pubkey_signature_header(self) -> str: "sender": self.account.get_address(), "payload": pubkey_payload, "signature": pubkey_signature, - "content": {"domain": urlparse(self.node_url).hostname}, + "content": {"domain": self.node_domain}, } ) @@ -73,7 +73,7 @@ async def _generate_header( self, vm_id: ItemHash, operation: str, method: str ) -> Tuple[str, Dict[str, str]]: payload = create_vm_control_payload( - vm_id, operation, domain=urlparse(self.node_url).hostname, method=method + vm_id, operation, domain=self.node_domain, method=method ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) @@ -90,6 +90,13 @@ async def _generate_header( path = payload["path"] return f"{self.node_url}{path}", headers + @property + def node_domain(self) -> str: + domain = urlparse(self.node_url).hostname + if not domain: + raise Exception("Could not parse node domain") + return domain + async def perform_operation( self, vm_id: ItemHash, operation: str, method: str = "POST" ) -> Tuple[Optional[int], str]: @@ -120,7 +127,7 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]: ) payload = create_vm_control_payload( - vm_id, "stream_logs", method="get", domain=urlparse(self.node_url).hostname + vm_id, "stream_logs", method="get", domain=self.node_domain ) signed_operation = sign_vm_control_payload(payload, self.ephemeral_key) path = payload["path"] diff --git a/tests/unit/test_vm_client.py b/tests/unit/test_vm_client.py index 6dcd9fdb..ac516564 100644 --- a/tests/unit/test_vm_client.py +++ b/tests/unit/test_vm_client.py @@ -290,7 +290,7 @@ async def test_vm_client_generate_correct_authentication_headers(): session=aiohttp.ClientSession(), ) - path, headers = await vm_client._generate_header(vm_id, "reboot") + path, headers = await vm_client._generate_header(vm_id, "reboot", method="post") signed_pubkey = SignedPubKeyHeader.parse_raw(headers["X-SignedPubKey"]) signed_operation = SignedOperation.parse_raw(headers["X-SignedOperation"]) address = verify_signed_operation(signed_operation, signed_pubkey) From dd2639a5402ae572228f7438ccbc961c5751ac4f Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Thu, 4 Jul 2024 21:41:48 +0200 Subject: [PATCH 11/14] Fix tests update 'vendorized' aleph-vm auth file from source --- tests/unit/aleph_vm_authentication.py | 40 +++++++++------------------ tests/unit/test_vm_client.py | 25 ++++++++--------- 2 files changed, 25 insertions(+), 40 deletions(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 7d213547..c7348685 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -45,11 +45,10 @@ def verify_wallet_signature(signature: bytes, message: str, address: str) -> boo class SignedPubKeyPayload(BaseModel): """This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf.""" - pubkey: Dict[str, Any] + pubkey: dict[str, Any] # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} # alg: Literal["ECDSA"] - domain: str address: str expires: str @@ -77,7 +76,7 @@ def payload_must_be_hex(cls, value: bytes) -> bytes: return bytes_from_hex(value.decode()) @root_validator(pre=False, skip_on_failure=True) - def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: + def check_expiry(cls, values) -> dict[str, bytes]: """Check that the token has not expired""" payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) @@ -104,18 +103,18 @@ def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]: @property def content(self) -> SignedPubKeyPayload: """Return the content of the header""" - return SignedPubKeyPayload.parse_raw(self.payload) class SignedOperationPayload(BaseModel): time: datetime.datetime method: Union[Literal["POST"], Literal["GET"]] + domain: str path: str # body_sha256: str # disabled since there is no body @validator("time") - def time_is_current(cls, value: datetime.datetime) -> datetime.datetime: + def time_is_current(cls, v: datetime.datetime) -> datetime.datetime: """Check that the time is current and the payload is not a replay attack.""" max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta( minutes=2 @@ -123,14 +122,11 @@ def time_is_current(cls, value: datetime.datetime) -> datetime.datetime: max_future = datetime.datetime.now( tz=datetime.timezone.utc ) + datetime.timedelta(minutes=2) - - if value < max_past: + if v < max_past: raise ValueError("Time is too far in the past") - - if value > max_future: + if v > max_future: raise ValueError("Time is too far in the future") - - return value + return v class SignedOperation(BaseModel): @@ -152,12 +148,10 @@ def signature_must_be_hex(cls, value: str) -> bytes: raise error @validator("payload") - def payload_must_be_hex(cls, value: bytes) -> bytes: + def payload_must_be_hex(cls, v) -> bytes: """Convert the payload from hexadecimal to bytes""" - - v = bytes_from_hex(value.decode()) + v = bytes.fromhex(v.decode()) _ = SignedOperationPayload.parse_raw(v) - return v @property @@ -197,7 +191,6 @@ def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader: if str(err.exc) == "Invalid signature": raise web.HTTPUnauthorized(reason="Invalid signature") from errors - else: raise errors @@ -207,13 +200,10 @@ def get_signed_operation(request: web.Request) -> SignedOperation: try: signed_operation = request.headers["X-SignedOperation"] return SignedOperation.parse_raw(signed_operation) - except KeyError as error: raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error - except json.JSONDecodeError as error: raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error - except ValidationError as error: logger.debug(f"Invalid X-SignedOperation fields: {error}") raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error @@ -244,9 +234,9 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) signed_pubkey = get_signed_pubkey(request) signed_operation = get_signed_operation(request) - if signed_pubkey.content.domain != domain_name: + if signed_operation.content.domain != domain_name: logger.debug( - f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" + f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") @@ -255,13 +245,11 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) f"Invalid path '{signed_operation.content.path}' != '{request.path}'" ) raise web.HTTPUnauthorized(reason="Invalid path") - if signed_operation.content.method != request.method: logger.debug( f"Invalid method '{signed_operation.content.method}' != '{request.method}'" ) raise web.HTTPUnauthorized(reason="Invalid method") - return verify_signed_operation(signed_operation, signed_pubkey) @@ -271,20 +259,17 @@ async def authenticate_websocket_message( """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) - - if signed_pubkey.content.domain != domain_name: + if signed_operation.content.domain != domain_name: logger.debug( f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") - return verify_signed_operation(signed_operation, signed_pubkey) def require_jwk_authentication( handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]] ) -> Callable[[web.Request], Awaitable[web.StreamResponse]]: - @functools.wraps(handler) async def wrapper(request): try: @@ -296,6 +281,7 @@ async def wrapper(request): logging.exception(e) raise + # authenticated_sender is the authenticted wallet address of the requester (as a string) response = await handler(request, authenticated_sender) return response diff --git a/tests/unit/test_vm_client.py b/tests/unit/test_vm_client.py index ac516564..7cc9a2c3 100644 --- a/tests/unit/test_vm_client.py +++ b/tests/unit/test_vm_client.py @@ -1,4 +1,3 @@ -import json from urllib.parse import urlparse import aiohttp @@ -173,7 +172,7 @@ async def websocket_handler(request): app = web.Application() app.router.add_route( - "GET", "/control/machine/{vm_id}/logs", websocket_handler + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler ) # Update route to match the URL client = await aiohttp_client(app) @@ -202,7 +201,9 @@ async def test_authenticate_jwk(aiohttp_client): vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe") async def test_authenticate_route(request): - address = await authenticate_jwk(request, domain_name=urlparse(node_url).netloc) + address = await authenticate_jwk( + request, domain_name=urlparse(node_url).hostname + ) assert vm_client.account.get_address() == address return web.Response(text="ok") @@ -222,7 +223,7 @@ async def test_authenticate_route(request): ) status_code, response_text = await vm_client.stop_instance(vm_id) - assert status_code == 200 + assert status_code == 200, response_text assert response_text == "ok" await vm_client.session.close() @@ -239,22 +240,19 @@ async def websocket_handler(request): first_message = await ws.receive_json() credentials = first_message["auth"] - address = await authenticate_websocket_message( - { - "X-SignedPubKey": json.loads(credentials["X-SignedPubKey"]), - "X-SignedOperation": json.loads(credentials["X-SignedOperation"]), - }, - domain_name=urlparse(node_url).netloc, + sender_address = await authenticate_websocket_message( + credentials, + domain_name=urlparse(node_url).hostname, ) - assert vm_client.account.get_address() == address - await ws.send_str(address) + assert vm_client.account.get_address() == sender_address + await ws.send_str(sender_address) return ws app = web.Application() app.router.add_route( - "GET", "/control/machine/{vm_id}/logs", websocket_handler + "GET", "/control/machine/{vm_id}/stream_logs", websocket_handler ) # Update route to match the URL client = await aiohttp_client(app) @@ -268,6 +266,7 @@ async def websocket_handler(request): ) valid = False + async for address in vm_client.get_logs(vm_id): assert address == vm_client.account.get_address() valid = True From 04d31f7f16e3a581a3a6970a21d5e1484fdaa4d3 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Thu, 4 Jul 2024 21:56:43 +0200 Subject: [PATCH 12/14] mypy --- tests/unit/aleph_vm_authentication.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index c7348685..46faee1e 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -4,7 +4,7 @@ import json import logging from collections.abc import Awaitable, Coroutine -from typing import Any, Callable, Dict, Literal, Union +from typing import Any, Callable, Dict, Literal, Union, Optional import cryptography.exceptions import pydantic @@ -45,7 +45,7 @@ def verify_wallet_signature(signature: bytes, message: str, address: str) -> boo class SignedPubKeyPayload(BaseModel): """This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf.""" - pubkey: dict[str, Any] + pubkey: Dict[str, Any] # {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', # 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'} # alg: Literal["ECDSA"] @@ -76,7 +76,7 @@ def payload_must_be_hex(cls, value: bytes) -> bytes: return bytes_from_hex(value.decode()) @root_validator(pre=False, skip_on_failure=True) - def check_expiry(cls, values) -> dict[str, bytes]: + def check_expiry(cls, values) -> Dict[str, bytes]: """Check that the token has not expired""" payload: bytes = values["payload"] content = SignedPubKeyPayload.parse_raw(payload) @@ -229,7 +229,9 @@ def verify_signed_operation( raise web.HTTPUnauthorized(reason="Signature could not verified") -async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) -> str: +async def authenticate_jwk( + request: web.Request, domain_name: Optional[str] = DOMAIN_NAME +) -> str: """Authenticate a request using the X-SignedPubKey and X-SignedOperation headers.""" signed_pubkey = get_signed_pubkey(request) signed_operation = get_signed_operation(request) @@ -254,7 +256,7 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME) async def authenticate_websocket_message( - message, domain_name: str = DOMAIN_NAME + message, domain_name: Optional[str] = DOMAIN_NAME ) -> str: """Authenticate a websocket message since JS cannot configure headers on WebSockets.""" signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"]) From 1bbce238904b054b231964b617e830a708c022e7 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Thu, 4 Jul 2024 21:59:56 +0200 Subject: [PATCH 13/14] isort --- tests/unit/aleph_vm_authentication.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 46faee1e..491da51a 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -4,7 +4,7 @@ import json import logging from collections.abc import Awaitable, Coroutine -from typing import Any, Callable, Dict, Literal, Union, Optional +from typing import Any, Callable, Dict, Literal, Optional, Union import cryptography.exceptions import pydantic From 196fd92deb857b60fecde721e9565cd8f95d5a48 Mon Sep 17 00:00:00 2001 From: Olivier Le Thanh Duong Date: Thu, 4 Jul 2024 22:19:02 +0200 Subject: [PATCH 14/14] Fix confidential tests --- tests/unit/test_vm_confidential_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_vm_confidential_client.py b/tests/unit/test_vm_confidential_client.py index cc25cf9f..832871ff 100644 --- a/tests/unit/test_vm_confidential_client.py +++ b/tests/unit/test_vm_confidential_client.py @@ -80,6 +80,7 @@ async def test_confidential_initialize_instance(): "session": tmp_file_bytes, "godh": tmp_file_bytes, }, + json=None, headers=headers, ) await vm_client.session.close() @@ -173,7 +174,7 @@ async def test_confidential_inject_secret_instance(): m.assert_called_once_with( url, method="POST", - data={ + json={ "secret": test_secret, "packet_header": packet_header, },