From 906bdbf511bcc98a4d610089199d3d48a61dce1f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Nov 2022 23:42:11 +0100 Subject: [PATCH 1/3] add test for dispatcher kernel signature consistency --- test/test_prototype_transforms_functional.py | 36 ++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 20f5e5330ff..405cf5ed507 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1,3 +1,4 @@ +import inspect import math import os import re @@ -314,6 +315,41 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on): spy.assert_called_once() + @pytest.mark.parametrize( + ("dispatcher_info", "feature_type", "kernel_info"), + [ + pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}") + for dispatcher_info in DISPATCHER_INFOS + for feature_type, kernel_info in dispatcher_info.kernel_infos.items() + ], + ) + def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info): + dispatcher_signature = inspect.signature(dispatcher_info.dispatcher) + dispatcher_params = list(dispatcher_signature.parameters.values())[1:] + + kernel_signature = inspect.signature(kernel_info.kernel) + kernel_params = list(kernel_signature.parameters.values())[1:] + + # We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be + # explicit passed to the kernel. + feature_type_metadata = feature_type.__annotations__.keys() + kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata] + + dispatcher_params = iter(dispatcher_params) + for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params): + try: + # In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out + # dispatcher parameters that have no kernel equivalent while keeping the order intact. + while dispatcher_param.name != kernel_param.name: + dispatcher_param = next(dispatcher_params) + except StopIteration: + raise AssertionError( + f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` " + f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`." + ) from None + + assert dispatcher_param == kernel_param + @pytest.mark.parametrize( ("alias", "target"), From 96d8bffcb7381c3216a327a8afeeee7776823eac Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Nov 2022 14:59:29 +0100 Subject: [PATCH 2/3] add dispatcher feature signature consistency test --- test/test_prototype_transforms_functional.py | 24 ++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index abc5f3b9f6f..7280f3e13be 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -3,6 +3,8 @@ import os import re +from typing import get_type_hints + import numpy as np import PIL.Image import pytest @@ -350,6 +352,28 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature assert dispatcher_param == kernel_param + @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id) + def test_dispatcher_feature_signatures_consistency(self, info): + try: + feature_method = getattr(features._Feature, info.id) + except AttributeError: + pytest.skip("foo") + + dispatcher_signature = inspect.signature(info.dispatcher) + dispatcher_params = list(dispatcher_signature.parameters.values())[1:] + + feature_signature = inspect.signature(feature_method) + feature_params = list(feature_signature.parameters.values())[1:] + + # Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined, + # the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively + # concrete dispatcher annotations. + feature_annotations = get_type_hints(feature_method) + for param in feature_params: + param._annotation = feature_annotations[param.name] + + assert dispatcher_params == feature_params + @pytest.mark.parametrize( ("alias", "target"), From 2f4cf3d88af6f80ac296f504c7939ed67641a65f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Nov 2022 15:22:07 +0100 Subject: [PATCH 3/3] fix error message --- test/test_prototype_transforms_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 7280f3e13be..8d7e4a1b68d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -357,7 +357,7 @@ def test_dispatcher_feature_signatures_consistency(self, info): try: feature_method = getattr(features._Feature, info.id) except AttributeError: - pytest.skip("foo") + pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.") dispatcher_signature = inspect.signature(info.dispatcher) dispatcher_params = list(dispatcher_signature.parameters.values())[1:]