Skip to content

Commit 247dbfc

Browse files
nesitorolethanh
andauthored
Implement VmConfidentialClient class (#138)
* Problem: A user cannot initialize an already created confidential VM. Solution: Implement `VmConfidentialClient` class to be able to initialize and interact with confidential VMs. * Problem: Auth was not working Corrections: * Measurement type returned was missing field needed for validation of measurements * Port number was not handled correctly in authentifaction * Adapt to new auth protocol where domain is moved to the operation field (While keeping compat with the old format) * Get measurement was not working since signed with the wrong method * inject_secret was not sending a json * Websocked auth was sending a twice serialized json * update 'vendorized' aleph-vm auth file from source Co-authored-by: Olivier Le Thanh Duong <[email protected]>
1 parent d9b1892 commit 247dbfc

File tree

8 files changed

+673
-55
lines changed

8 files changed

+673
-55
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies = [
3232
"python-magic",
3333
"typer",
3434
"typing_extensions",
35+
"aioresponses>=0.7.6"
3536
]
3637

3738
[project.optional-dependencies]

src/aleph/sdk/client/vmclient.py renamed to src/aleph/sdk/client/vm_client.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _generate_pubkey_payload(self) -> Dict[str, Any]:
4444
return {
4545
"pubkey": json.loads(self.ephemeral_key.export_public()),
4646
"alg": "ECDSA",
47-
"domain": urlparse(self.node_url).netloc,
47+
"domain": self.node_domain,
4848
"address": self.account.get_address(),
4949
"expires": (
5050
datetime.datetime.utcnow() + datetime.timedelta(days=1)
@@ -65,14 +65,16 @@ async def _generate_pubkey_signature_header(self) -> str:
6565
"sender": self.account.get_address(),
6666
"payload": pubkey_payload,
6767
"signature": pubkey_signature,
68-
"content": {"domain": urlparse(self.node_url).netloc},
68+
"content": {"domain": self.node_domain},
6969
}
7070
)
7171

7272
async def _generate_header(
73-
self, vm_id: ItemHash, operation: str
73+
self, vm_id: ItemHash, operation: str, method: str
7474
) -> Tuple[str, Dict[str, str]]:
75-
payload = create_vm_control_payload(vm_id, operation)
75+
payload = create_vm_control_payload(
76+
vm_id, operation, domain=self.node_domain, method=method
77+
)
7678
signed_operation = sign_vm_control_payload(payload, self.ephemeral_key)
7779

7880
if not self.pubkey_signature_header:
@@ -88,18 +90,29 @@ async def _generate_header(
8890
path = payload["path"]
8991
return f"{self.node_url}{path}", headers
9092

93+
@property
94+
def node_domain(self) -> str:
95+
domain = urlparse(self.node_url).hostname
96+
if not domain:
97+
raise Exception("Could not parse node domain")
98+
return domain
99+
91100
async def perform_operation(
92-
self, vm_id: ItemHash, operation: str
101+
self, vm_id: ItemHash, operation: str, method: str = "POST"
93102
) -> Tuple[Optional[int], str]:
94103
if not self.pubkey_signature_header:
95104
self.pubkey_signature_header = (
96105
await self._generate_pubkey_signature_header()
97106
)
98107

99-
url, header = await self._generate_header(vm_id=vm_id, operation=operation)
108+
url, header = await self._generate_header(
109+
vm_id=vm_id, operation=operation, method=method
110+
)
100111

101112
try:
102-
async with self.session.post(url, headers=header) as response:
113+
async with self.session.request(
114+
method=method, url=url, headers=header
115+
) as response:
103116
response_text = await response.text()
104117
return response.status, response_text
105118

@@ -113,16 +126,18 @@ async def get_logs(self, vm_id: ItemHash) -> AsyncGenerator[str, None]:
113126
await self._generate_pubkey_signature_header()
114127
)
115128

116-
payload = create_vm_control_payload(vm_id, "stream_logs")
129+
payload = create_vm_control_payload(
130+
vm_id, "stream_logs", method="get", domain=self.node_domain
131+
)
117132
signed_operation = sign_vm_control_payload(payload, self.ephemeral_key)
118133
path = payload["path"]
119134
ws_url = f"{self.node_url}{path}"
120135

121136
async with self.session.ws_connect(ws_url) as ws:
122137
auth_message = {
123138
"auth": {
124-
"X-SignedPubKey": self.pubkey_signature_header,
125-
"X-SignedOperation": signed_operation,
139+
"X-SignedPubKey": json.loads(self.pubkey_signature_header),
140+
"X-SignedOperation": json.loads(signed_operation),
126141
}
127142
}
128143
await ws.send_json(auth_message)
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import base64
2+
import json
3+
import logging
4+
import os
5+
import tempfile
6+
from pathlib import Path
7+
from typing import Any, Dict, Optional, Tuple
8+
9+
import aiohttp
10+
from aleph_message.models import ItemHash
11+
12+
from aleph.sdk.client.vm_client import VmClient
13+
from aleph.sdk.types import Account, SEVMeasurement
14+
from aleph.sdk.utils import (
15+
compute_confidential_measure,
16+
encrypt_secret_table,
17+
get_vm_measure,
18+
make_packet_header,
19+
make_secret_table,
20+
run_in_subprocess,
21+
)
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
class VmConfidentialClient(VmClient):
27+
sevctl_path: Path
28+
29+
def __init__(
30+
self,
31+
account: Account,
32+
sevctl_path: Path,
33+
node_url: str = "",
34+
session: Optional[aiohttp.ClientSession] = None,
35+
):
36+
super().__init__(account, node_url, session)
37+
self.sevctl_path = sevctl_path
38+
39+
async def get_certificates(self) -> Tuple[Optional[int], str]:
40+
"""
41+
Get platform confidential certificate
42+
"""
43+
44+
url = f"{self.node_url}/about/certificates"
45+
try:
46+
async with self.session.get(url) as response:
47+
data = await response.read()
48+
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
49+
tmp_file.write(data)
50+
return response.status, tmp_file.name
51+
52+
except aiohttp.ClientError as e:
53+
logger.error(
54+
f"HTTP error getting node certificates on {self.node_url}: {str(e)}"
55+
)
56+
return None, str(e)
57+
58+
async def create_session(
59+
self, vm_id: ItemHash, certificate_path: Path, policy: int
60+
) -> Path:
61+
"""
62+
Create new confidential session
63+
"""
64+
65+
current_path = Path().cwd()
66+
args = [
67+
"session",
68+
"--name",
69+
str(vm_id),
70+
str(certificate_path),
71+
str(policy),
72+
]
73+
try:
74+
# TODO: Check command result
75+
await self.sevctl_cmd(*args)
76+
return current_path
77+
except Exception as e:
78+
raise ValueError(f"Session creation have failed, reason: {str(e)}")
79+
80+
async def initialize(self, vm_id: ItemHash, session: Path, godh: Path) -> str:
81+
"""
82+
Initialize Confidential VM negociation passing the needed session files
83+
"""
84+
85+
session_file = session.read_bytes()
86+
godh_file = godh.read_bytes()
87+
params = {
88+
"session": session_file,
89+
"godh": godh_file,
90+
}
91+
return await self.perform_confidential_operation(
92+
vm_id, "confidential/initialize", params=params
93+
)
94+
95+
async def measurement(self, vm_id: ItemHash) -> SEVMeasurement:
96+
"""
97+
Fetch VM confidential measurement
98+
"""
99+
100+
if not self.pubkey_signature_header:
101+
self.pubkey_signature_header = (
102+
await self._generate_pubkey_signature_header()
103+
)
104+
105+
status, text = await self.perform_operation(
106+
vm_id, "confidential/measurement", method="GET"
107+
)
108+
sev_mesurement = SEVMeasurement.parse_raw(text)
109+
return sev_mesurement
110+
111+
async def validate_measure(
112+
self, sev_data: SEVMeasurement, tik_path: Path, firmware_hash: str
113+
) -> bool:
114+
"""
115+
Validate VM confidential measurement
116+
"""
117+
118+
tik = tik_path.read_bytes()
119+
vm_measure, nonce = get_vm_measure(sev_data)
120+
121+
expected_measure = compute_confidential_measure(
122+
sev_info=sev_data.sev_info,
123+
tik=tik,
124+
expected_hash=firmware_hash,
125+
nonce=nonce,
126+
).digest()
127+
return expected_measure == vm_measure
128+
129+
async def build_secret(
130+
self, tek_path: Path, tik_path: Path, sev_data: SEVMeasurement, secret: str
131+
) -> Tuple[str, str]:
132+
"""
133+
Build disk secret to be injected on the confidential VM
134+
"""
135+
136+
tek = tek_path.read_bytes()
137+
tik = tik_path.read_bytes()
138+
139+
vm_measure, _ = get_vm_measure(sev_data)
140+
141+
iv = os.urandom(16)
142+
secret_table = make_secret_table(secret)
143+
encrypted_secret_table = encrypt_secret_table(
144+
secret_table=secret_table, tek=tek, iv=iv
145+
)
146+
147+
packet_header = make_packet_header(
148+
vm_measure=vm_measure,
149+
encrypted_secret_table=encrypted_secret_table,
150+
secret_table_size=len(secret_table),
151+
tik=tik,
152+
iv=iv,
153+
)
154+
155+
encoded_packet_header = base64.b64encode(packet_header).decode()
156+
encoded_secret = base64.b64encode(encrypted_secret_table).decode()
157+
158+
return encoded_packet_header, encoded_secret
159+
160+
async def inject_secret(
161+
self, vm_id: ItemHash, packet_header: str, secret: str
162+
) -> Dict:
163+
"""
164+
Send the secret by the encrypted channel to boot up the VM
165+
"""
166+
167+
params = {
168+
"packet_header": packet_header,
169+
"secret": secret,
170+
}
171+
text = await self.perform_confidential_operation(
172+
vm_id, "confidential/inject_secret", json=params
173+
)
174+
175+
return json.loads(text)
176+
177+
async def perform_confidential_operation(
178+
self,
179+
vm_id: ItemHash,
180+
operation: str,
181+
params: Optional[Dict[str, Any]] = None,
182+
json=None,
183+
) -> str:
184+
"""
185+
Send confidential operations to the CRN passing the auth headers on each request
186+
"""
187+
188+
if not self.pubkey_signature_header:
189+
self.pubkey_signature_header = (
190+
await self._generate_pubkey_signature_header()
191+
)
192+
193+
url, header = await self._generate_header(
194+
vm_id=vm_id, operation=operation, method="post"
195+
)
196+
197+
try:
198+
async with self.session.post(
199+
url, headers=header, data=params, json=json
200+
) as response:
201+
response.raise_for_status()
202+
response_text = await response.text()
203+
return response_text
204+
205+
except aiohttp.ClientError as e:
206+
raise ValueError(f"HTTP error during operation {operation}: {str(e)}")
207+
208+
async def sevctl_cmd(self, *args) -> bytes:
209+
"""
210+
Execute `sevctl` command with given arguments
211+
"""
212+
213+
return await run_in_subprocess(
214+
[str(self.sevctl_path), *args],
215+
check=True,
216+
)

src/aleph/sdk/types.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from enum import Enum
33
from typing import Dict, Protocol, TypeVar
44

5+
from pydantic import BaseModel
6+
57
__all__ = ("StorageEnum", "Account", "AccountFromPrivateKey", "GenericMessage")
68

79
from aleph_message.models import AlephMessage
@@ -39,3 +41,26 @@ async def sign_raw(self, buffer: bytes) -> bytes: ...
3941

4042

4143
GenericMessage = TypeVar("GenericMessage", bound=AlephMessage)
44+
45+
46+
class SEVInfo(BaseModel):
47+
"""
48+
An AMD SEV platform information.
49+
"""
50+
51+
enabled: bool
52+
api_major: int
53+
api_minor: int
54+
build_id: int
55+
policy: int
56+
state: str
57+
handle: int
58+
59+
60+
class SEVMeasurement(BaseModel):
61+
"""
62+
A SEV measurement data get from Qemu measurement.
63+
"""
64+
65+
sev_info: SEVInfo
66+
launch_measure: str

0 commit comments

Comments
 (0)