|
| 1 | +import inspect |
1 | 2 | import math |
2 | 3 | import os |
3 | 4 | import re |
4 | 5 |
|
| 6 | +from typing import get_type_hints |
| 7 | + |
5 | 8 | import numpy as np |
6 | 9 | import PIL.Image |
7 | 10 | import pytest |
@@ -314,6 +317,63 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on): |
314 | 317 |
|
315 | 318 | spy.assert_called_once() |
316 | 319 |
|
| 320 | + @pytest.mark.parametrize( |
| 321 | + ("dispatcher_info", "feature_type", "kernel_info"), |
| 322 | + [ |
| 323 | + pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}") |
| 324 | + for dispatcher_info in DISPATCHER_INFOS |
| 325 | + for feature_type, kernel_info in dispatcher_info.kernel_infos.items() |
| 326 | + ], |
| 327 | + ) |
| 328 | + def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info): |
| 329 | + dispatcher_signature = inspect.signature(dispatcher_info.dispatcher) |
| 330 | + dispatcher_params = list(dispatcher_signature.parameters.values())[1:] |
| 331 | + |
| 332 | + kernel_signature = inspect.signature(kernel_info.kernel) |
| 333 | + kernel_params = list(kernel_signature.parameters.values())[1:] |
| 334 | + |
| 335 | + # We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be |
| 336 | + # explicit passed to the kernel. |
| 337 | + feature_type_metadata = feature_type.__annotations__.keys() |
| 338 | + kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata] |
| 339 | + |
| 340 | + dispatcher_params = iter(dispatcher_params) |
| 341 | + for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params): |
| 342 | + try: |
| 343 | + # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out |
| 344 | + # dispatcher parameters that have no kernel equivalent while keeping the order intact. |
| 345 | + while dispatcher_param.name != kernel_param.name: |
| 346 | + dispatcher_param = next(dispatcher_params) |
| 347 | + except StopIteration: |
| 348 | + raise AssertionError( |
| 349 | + f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` " |
| 350 | + f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`." |
| 351 | + ) from None |
| 352 | + |
| 353 | + assert dispatcher_param == kernel_param |
| 354 | + |
| 355 | + @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id) |
| 356 | + def test_dispatcher_feature_signatures_consistency(self, info): |
| 357 | + try: |
| 358 | + feature_method = getattr(features._Feature, info.id) |
| 359 | + except AttributeError: |
| 360 | + pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.") |
| 361 | + |
| 362 | + dispatcher_signature = inspect.signature(info.dispatcher) |
| 363 | + dispatcher_params = list(dispatcher_signature.parameters.values())[1:] |
| 364 | + |
| 365 | + feature_signature = inspect.signature(feature_method) |
| 366 | + feature_params = list(feature_signature.parameters.values())[1:] |
| 367 | + |
| 368 | + # Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined, |
| 369 | + # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively |
| 370 | + # concrete dispatcher annotations. |
| 371 | + feature_annotations = get_type_hints(feature_method) |
| 372 | + for param in feature_params: |
| 373 | + param._annotation = feature_annotations[param.name] |
| 374 | + |
| 375 | + assert dispatcher_params == feature_params |
| 376 | + |
317 | 377 |
|
318 | 378 | @pytest.mark.parametrize( |
319 | 379 | ("alias", "target"), |
|
0 commit comments