Skip to content

Commit fba8bc1

Browse files
committed
Fix tests
update 'vendorized' aleph-vm auth file from source
1 parent e852717 commit fba8bc1

File tree

3 files changed

+32
-44
lines changed

3 files changed

+32
-44
lines changed

tests/unit/aleph_vm_authentication.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import logging
66
from collections.abc import Awaitable, Coroutine
7-
from typing import Any, Callable, Dict, Literal, Union
7+
from typing import Any, Callable, Literal, Union
88

99
import cryptography.exceptions
1010
import pydantic
@@ -45,11 +45,10 @@ def verify_wallet_signature(signature: bytes, message: str, address: str) -> boo
4545
class SignedPubKeyPayload(BaseModel):
4646
"""This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf."""
4747

48-
pubkey: Dict[str, Any]
48+
pubkey: dict[str, Any]
4949
# {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC',
5050
# 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'}
5151
# alg: Literal["ECDSA"]
52-
domain: str
5352
address: str
5453
expires: str
5554

@@ -77,7 +76,7 @@ def payload_must_be_hex(cls, value: bytes) -> bytes:
7776
return bytes_from_hex(value.decode())
7877

7978
@root_validator(pre=False, skip_on_failure=True)
80-
def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]:
79+
def check_expiry(cls, values) -> dict[str, bytes]:
8180
"""Check that the token has not expired"""
8281
payload: bytes = values["payload"]
8382
content = SignedPubKeyPayload.parse_raw(payload)
@@ -104,33 +103,30 @@ def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]:
104103
@property
105104
def content(self) -> SignedPubKeyPayload:
106105
"""Return the content of the header"""
107-
108106
return SignedPubKeyPayload.parse_raw(self.payload)
109107

110108

111109
class SignedOperationPayload(BaseModel):
112110
time: datetime.datetime
113111
method: Union[Literal["POST"], Literal["GET"]]
112+
domain: str
114113
path: str
115114
# body_sha256: str # disabled since there is no body
116115

117116
@validator("time")
118-
def time_is_current(cls, value: datetime.datetime) -> datetime.datetime:
117+
def time_is_current(cls, v: datetime.datetime) -> datetime.datetime:
119118
"""Check that the time is current and the payload is not a replay attack."""
120119
max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(
121120
minutes=2
122121
)
123122
max_future = datetime.datetime.now(
124123
tz=datetime.timezone.utc
125124
) + datetime.timedelta(minutes=2)
126-
127-
if value < max_past:
125+
if v < max_past:
128126
raise ValueError("Time is too far in the past")
129-
130-
if value > max_future:
127+
if v > max_future:
131128
raise ValueError("Time is too far in the future")
132-
133-
return value
129+
return v
134130

135131

136132
class SignedOperation(BaseModel):
@@ -152,12 +148,10 @@ def signature_must_be_hex(cls, value: str) -> bytes:
152148
raise error
153149

154150
@validator("payload")
155-
def payload_must_be_hex(cls, value: bytes) -> bytes:
151+
def payload_must_be_hex(cls, v) -> bytes:
156152
"""Convert the payload from hexadecimal to bytes"""
157-
158-
v = bytes_from_hex(value.decode())
153+
v = bytes.fromhex(v.decode())
159154
_ = SignedOperationPayload.parse_raw(v)
160-
161155
return v
162156

163157
@property
@@ -197,7 +191,6 @@ def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader:
197191

198192
if str(err.exc) == "Invalid signature":
199193
raise web.HTTPUnauthorized(reason="Invalid signature") from errors
200-
201194
else:
202195
raise errors
203196

@@ -207,13 +200,10 @@ def get_signed_operation(request: web.Request) -> SignedOperation:
207200
try:
208201
signed_operation = request.headers["X-SignedOperation"]
209202
return SignedOperation.parse_raw(signed_operation)
210-
211203
except KeyError as error:
212204
raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error
213-
214205
except json.JSONDecodeError as error:
215206
raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error
216-
217207
except ValidationError as error:
218208
logger.debug(f"Invalid X-SignedOperation fields: {error}")
219209
raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error
@@ -244,9 +234,9 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME)
244234
signed_pubkey = get_signed_pubkey(request)
245235
signed_operation = get_signed_operation(request)
246236

247-
if signed_pubkey.content.domain != domain_name:
237+
if signed_operation.content.domain != domain_name:
248238
logger.debug(
249-
f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'"
239+
f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'"
250240
)
251241
raise web.HTTPUnauthorized(reason="Invalid domain")
252242

@@ -255,13 +245,11 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME)
255245
f"Invalid path '{signed_operation.content.path}' != '{request.path}'"
256246
)
257247
raise web.HTTPUnauthorized(reason="Invalid path")
258-
259248
if signed_operation.content.method != request.method:
260249
logger.debug(
261250
f"Invalid method '{signed_operation.content.method}' != '{request.method}'"
262251
)
263252
raise web.HTTPUnauthorized(reason="Invalid method")
264-
265253
return verify_signed_operation(signed_operation, signed_pubkey)
266254

267255

@@ -271,20 +259,17 @@ async def authenticate_websocket_message(
271259
"""Authenticate a websocket message since JS cannot configure headers on WebSockets."""
272260
signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"])
273261
signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"])
274-
275-
if signed_pubkey.content.domain != domain_name:
262+
if signed_operation.content.domain != domain_name:
276263
logger.debug(
277264
f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'"
278265
)
279266
raise web.HTTPUnauthorized(reason="Invalid domain")
280-
281267
return verify_signed_operation(signed_operation, signed_pubkey)
282268

283269

284270
def require_jwk_authentication(
285271
handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]]
286272
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
287-
288273
@functools.wraps(handler)
289274
async def wrapper(request):
290275
try:
@@ -296,6 +281,7 @@ async def wrapper(request):
296281
logging.exception(e)
297282
raise
298283

284+
# authenticated_sender is the authenticted wallet address of the requester (as a string)
299285
response = await handler(request, authenticated_sender)
300286
return response
301287

tests/unit/conftest.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from aleph_message.models import AggregateMessage, AlephMessage, PostMessage
1212

1313
import aleph.sdk.chains.ethereum as ethereum
14+
1415
import aleph.sdk.chains.sol as solana
15-
import aleph.sdk.chains.substrate as substrate
16-
import aleph.sdk.chains.tezos as tezos
16+
17+
# import aleph.sdk.chains.substrate as substrate
18+
# import aleph.sdk.chains.tezos as tezos
1719
from aleph.sdk import AlephHttpClient, AuthenticatedAlephHttpClient
1820
from aleph.sdk.chains.common import get_fallback_private_key
1921
from aleph.sdk.types import Account
@@ -40,14 +42,14 @@ def solana_account() -> solana.SOLAccount:
4042

4143

4244
@pytest.fixture
43-
def tezos_account() -> tezos.TezosAccount:
45+
def tezos_account() -> "tezos.TezosAccount":
4446
with NamedTemporaryFile(delete=False) as private_key_file:
4547
private_key_file.close()
4648
yield tezos.get_fallback_account(path=Path(private_key_file.name))
4749

4850

4951
@pytest.fixture
50-
def substrate_account() -> substrate.DOTAccount:
52+
def substrate_account() -> "substrate.DOTAccount":
5153
with NamedTemporaryFile(delete=False) as private_key_file:
5254
private_key_file.close()
5355
yield substrate.get_fallback_account(path=Path(private_key_file.name))

tests/unit/test_vm_client.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ async def websocket_handler(request):
173173

174174
app = web.Application()
175175
app.router.add_route(
176-
"GET", "/control/machine/{vm_id}/logs", websocket_handler
176+
"GET", "/control/machine/{vm_id}/stream_logs", websocket_handler
177177
) # Update route to match the URL
178178

179179
client = await aiohttp_client(app)
@@ -202,7 +202,9 @@ async def test_authenticate_jwk(aiohttp_client):
202202
vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe")
203203

204204
async def test_authenticate_route(request):
205-
address = await authenticate_jwk(request, domain_name=urlparse(node_url).netloc)
205+
address = await authenticate_jwk(
206+
request, domain_name=urlparse(node_url).hostname
207+
)
206208
assert vm_client.account.get_address() == address
207209
return web.Response(text="ok")
208210

@@ -222,7 +224,7 @@ async def test_authenticate_route(request):
222224
)
223225

224226
status_code, response_text = await vm_client.stop_instance(vm_id)
225-
assert status_code == 200
227+
assert status_code == 200, response_text
226228
assert response_text == "ok"
227229

228230
await vm_client.session.close()
@@ -239,22 +241,19 @@ async def websocket_handler(request):
239241

240242
first_message = await ws.receive_json()
241243
credentials = first_message["auth"]
242-
address = await authenticate_websocket_message(
243-
{
244-
"X-SignedPubKey": json.loads(credentials["X-SignedPubKey"]),
245-
"X-SignedOperation": json.loads(credentials["X-SignedOperation"]),
246-
},
247-
domain_name=urlparse(node_url).netloc,
244+
sender_address = await authenticate_websocket_message(
245+
credentials,
246+
domain_name=urlparse(node_url).hostname,
248247
)
249248

250-
assert vm_client.account.get_address() == address
251-
await ws.send_str(address)
249+
assert vm_client.account.get_address() == sender_address
250+
await ws.send_str(sender_address)
252251

253252
return ws
254253

255254
app = web.Application()
256255
app.router.add_route(
257-
"GET", "/control/machine/{vm_id}/logs", websocket_handler
256+
"GET", "/control/machine/{vm_id}/stream_logs", websocket_handler
258257
) # Update route to match the URL
259258

260259
client = await aiohttp_client(app)
@@ -268,6 +267,7 @@ async def websocket_handler(request):
268267
)
269268

270269
valid = False
270+
271271
async for address in vm_client.get_logs(vm_id):
272272
assert address == vm_client.account.get_address()
273273
valid = True

0 commit comments

Comments
 (0)