diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 9917fea8218..8d7e4a1b68d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,7 +1,10 @@ +import inspect import math import os import re +from typing import get_type_hints + import numpy as np import PIL.Image import pytest @@ -314,6 +317,63 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on): spy.assert_called_once() + @pytest.mark.parametrize( + ("dispatcher_info", "feature_type", "kernel_info"), + [ + pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}") + for dispatcher_info in DISPATCHER_INFOS + for feature_type, kernel_info in dispatcher_info.kernel_infos.items() + ], + ) + def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info): + dispatcher_signature = inspect.signature(dispatcher_info.dispatcher) + dispatcher_params = list(dispatcher_signature.parameters.values())[1:] + + kernel_signature = inspect.signature(kernel_info.kernel) + kernel_params = list(kernel_signature.parameters.values())[1:] + + # We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be + # explicit passed to the kernel. + feature_type_metadata = feature_type.__annotations__.keys() + kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata] + + dispatcher_params = iter(dispatcher_params) + for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params): + try: + # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out + # dispatcher parameters that have no kernel equivalent while keeping the order intact. + while dispatcher_param.name != kernel_param.name: + dispatcher_param = next(dispatcher_params) + except StopIteration: + raise AssertionError( + f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` " + f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`." + ) from None + + assert dispatcher_param == kernel_param + + @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id) + def test_dispatcher_feature_signatures_consistency(self, info): + try: + feature_method = getattr(features._Feature, info.id) + except AttributeError: + pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.") + + dispatcher_signature = inspect.signature(info.dispatcher) + dispatcher_params = list(dispatcher_signature.parameters.values())[1:] + + feature_signature = inspect.signature(feature_method) + feature_params = list(feature_signature.parameters.values())[1:] + + # Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined, + # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively + # concrete dispatcher annotations. + feature_annotations = get_type_hints(feature_method) + for param in feature_params: + param._annotation = feature_annotations[param.name] + + assert dispatcher_params == feature_params + @pytest.mark.parametrize( ("alias", "target"),