Skip to content

Commit dd2639a

Browse files
committed
Fix tests
update 'vendorized' aleph-vm auth file from source
1 parent e852717 commit dd2639a

File tree

2 files changed

+25
-40
lines changed

2 files changed

+25
-40
lines changed

tests/unit/aleph_vm_authentication.py

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@ def verify_wallet_signature(signature: bytes, message: str, address: str) -> boo
4545
class SignedPubKeyPayload(BaseModel):
4646
"""This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf."""
4747

48-
pubkey: Dict[str, Any]
48+
pubkey: dict[str, Any]
4949
# {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC',
5050
# 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'}
5151
# alg: Literal["ECDSA"]
52-
domain: str
5352
address: str
5453
expires: str
5554

@@ -77,7 +76,7 @@ def payload_must_be_hex(cls, value: bytes) -> bytes:
7776
return bytes_from_hex(value.decode())
7877

7978
@root_validator(pre=False, skip_on_failure=True)
80-
def check_expiry(cls, values: Dict[str, bytes]) -> Dict[str, bytes]:
79+
def check_expiry(cls, values) -> dict[str, bytes]:
8180
"""Check that the token has not expired"""
8281
payload: bytes = values["payload"]
8382
content = SignedPubKeyPayload.parse_raw(payload)
@@ -104,33 +103,30 @@ def check_signature(cls, values: Dict[str, bytes]) -> Dict[str, bytes]:
104103
@property
105104
def content(self) -> SignedPubKeyPayload:
106105
"""Return the content of the header"""
107-
108106
return SignedPubKeyPayload.parse_raw(self.payload)
109107

110108

111109
class SignedOperationPayload(BaseModel):
112110
time: datetime.datetime
113111
method: Union[Literal["POST"], Literal["GET"]]
112+
domain: str
114113
path: str
115114
# body_sha256: str # disabled since there is no body
116115

117116
@validator("time")
118-
def time_is_current(cls, value: datetime.datetime) -> datetime.datetime:
117+
def time_is_current(cls, v: datetime.datetime) -> datetime.datetime:
119118
"""Check that the time is current and the payload is not a replay attack."""
120119
max_past = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(
121120
minutes=2
122121
)
123122
max_future = datetime.datetime.now(
124123
tz=datetime.timezone.utc
125124
) + datetime.timedelta(minutes=2)
126-
127-
if value < max_past:
125+
if v < max_past:
128126
raise ValueError("Time is too far in the past")
129-
130-
if value > max_future:
127+
if v > max_future:
131128
raise ValueError("Time is too far in the future")
132-
133-
return value
129+
return v
134130

135131

136132
class SignedOperation(BaseModel):
@@ -152,12 +148,10 @@ def signature_must_be_hex(cls, value: str) -> bytes:
152148
raise error
153149

154150
@validator("payload")
155-
def payload_must_be_hex(cls, value: bytes) -> bytes:
151+
def payload_must_be_hex(cls, v) -> bytes:
156152
"""Convert the payload from hexadecimal to bytes"""
157-
158-
v = bytes_from_hex(value.decode())
153+
v = bytes.fromhex(v.decode())
159154
_ = SignedOperationPayload.parse_raw(v)
160-
161155
return v
162156

163157
@property
@@ -197,7 +191,6 @@ def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader:
197191

198192
if str(err.exc) == "Invalid signature":
199193
raise web.HTTPUnauthorized(reason="Invalid signature") from errors
200-
201194
else:
202195
raise errors
203196

@@ -207,13 +200,10 @@ def get_signed_operation(request: web.Request) -> SignedOperation:
207200
try:
208201
signed_operation = request.headers["X-SignedOperation"]
209202
return SignedOperation.parse_raw(signed_operation)
210-
211203
except KeyError as error:
212204
raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error
213-
214205
except json.JSONDecodeError as error:
215206
raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error
216-
217207
except ValidationError as error:
218208
logger.debug(f"Invalid X-SignedOperation fields: {error}")
219209
raise web.HTTPBadRequest(reason="Invalid X-SignedOperation fields") from error
@@ -244,9 +234,9 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME)
244234
signed_pubkey = get_signed_pubkey(request)
245235
signed_operation = get_signed_operation(request)
246236

247-
if signed_pubkey.content.domain != domain_name:
237+
if signed_operation.content.domain != domain_name:
248238
logger.debug(
249-
f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'"
239+
f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'"
250240
)
251241
raise web.HTTPUnauthorized(reason="Invalid domain")
252242

@@ -255,13 +245,11 @@ async def authenticate_jwk(request: web.Request, domain_name: str = DOMAIN_NAME)
255245
f"Invalid path '{signed_operation.content.path}' != '{request.path}'"
256246
)
257247
raise web.HTTPUnauthorized(reason="Invalid path")
258-
259248
if signed_operation.content.method != request.method:
260249
logger.debug(
261250
f"Invalid method '{signed_operation.content.method}' != '{request.method}'"
262251
)
263252
raise web.HTTPUnauthorized(reason="Invalid method")
264-
265253
return verify_signed_operation(signed_operation, signed_pubkey)
266254

267255

@@ -271,20 +259,17 @@ async def authenticate_websocket_message(
271259
"""Authenticate a websocket message since JS cannot configure headers on WebSockets."""
272260
signed_pubkey = SignedPubKeyHeader.parse_obj(message["X-SignedPubKey"])
273261
signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"])
274-
275-
if signed_pubkey.content.domain != domain_name:
262+
if signed_operation.content.domain != domain_name:
276263
logger.debug(
277264
f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'"
278265
)
279266
raise web.HTTPUnauthorized(reason="Invalid domain")
280-
281267
return verify_signed_operation(signed_operation, signed_pubkey)
282268

283269

284270
def require_jwk_authentication(
285271
handler: Callable[[web.Request, str], Coroutine[Any, Any, web.StreamResponse]]
286272
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
287-
288273
@functools.wraps(handler)
289274
async def wrapper(request):
290275
try:
@@ -296,6 +281,7 @@ async def wrapper(request):
296281
logging.exception(e)
297282
raise
298283

284+
# authenticated_sender is the authenticted wallet address of the requester (as a string)
299285
response = await handler(request, authenticated_sender)
300286
return response
301287

tests/unit/test_vm_client.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
from urllib.parse import urlparse
32

43
import aiohttp
@@ -173,7 +172,7 @@ async def websocket_handler(request):
173172

174173
app = web.Application()
175174
app.router.add_route(
176-
"GET", "/control/machine/{vm_id}/logs", websocket_handler
175+
"GET", "/control/machine/{vm_id}/stream_logs", websocket_handler
177176
) # Update route to match the URL
178177

179178
client = await aiohttp_client(app)
@@ -202,7 +201,9 @@ async def test_authenticate_jwk(aiohttp_client):
202201
vm_id = ItemHash("cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe")
203202

204203
async def test_authenticate_route(request):
205-
address = await authenticate_jwk(request, domain_name=urlparse(node_url).netloc)
204+
address = await authenticate_jwk(
205+
request, domain_name=urlparse(node_url).hostname
206+
)
206207
assert vm_client.account.get_address() == address
207208
return web.Response(text="ok")
208209

@@ -222,7 +223,7 @@ async def test_authenticate_route(request):
222223
)
223224

224225
status_code, response_text = await vm_client.stop_instance(vm_id)
225-
assert status_code == 200
226+
assert status_code == 200, response_text
226227
assert response_text == "ok"
227228

228229
await vm_client.session.close()
@@ -239,22 +240,19 @@ async def websocket_handler(request):
239240

240241
first_message = await ws.receive_json()
241242
credentials = first_message["auth"]
242-
address = await authenticate_websocket_message(
243-
{
244-
"X-SignedPubKey": json.loads(credentials["X-SignedPubKey"]),
245-
"X-SignedOperation": json.loads(credentials["X-SignedOperation"]),
246-
},
247-
domain_name=urlparse(node_url).netloc,
243+
sender_address = await authenticate_websocket_message(
244+
credentials,
245+
domain_name=urlparse(node_url).hostname,
248246
)
249247

250-
assert vm_client.account.get_address() == address
251-
await ws.send_str(address)
248+
assert vm_client.account.get_address() == sender_address
249+
await ws.send_str(sender_address)
252250

253251
return ws
254252

255253
app = web.Application()
256254
app.router.add_route(
257-
"GET", "/control/machine/{vm_id}/logs", websocket_handler
255+
"GET", "/control/machine/{vm_id}/stream_logs", websocket_handler
258256
) # Update route to match the URL
259257

260258
client = await aiohttp_client(app)
@@ -268,6 +266,7 @@ async def websocket_handler(request):
268266
)
269267

270268
valid = False
269+
271270
async for address in vm_client.get_logs(vm_id):
272271
assert address == vm_client.account.get_address()
273272
valid = True

0 commit comments

Comments
 (0)