Skip to content
60 changes: 60 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import inspect
import math
import os
import re

from typing import get_type_hints

import numpy as np
import PIL.Image
import pytest
Expand Down Expand Up @@ -314,6 +317,63 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on):

spy.assert_called_once()

@pytest.mark.parametrize(
("dispatcher_info", "feature_type", "kernel_info"),
[
pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}")
for dispatcher_info in DISPATCHER_INFOS
for feature_type, kernel_info in dispatcher_info.kernel_infos.items()
],
)
def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info):
dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

kernel_signature = inspect.signature(kernel_info.kernel)
kernel_params = list(kernel_signature.parameters.values())[1:]

# We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be
# explicit passed to the kernel.
feature_type_metadata = feature_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata]

dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
try:
# In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
# dispatcher parameters that have no kernel equivalent while keeping the order intact.
while dispatcher_param.name != kernel_param.name:
dispatcher_param = next(dispatcher_params)
except StopIteration:
raise AssertionError(
f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
) from None

assert dispatcher_param == kernel_param

@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_feature_signatures_consistency(self, info):
try:
feature_method = getattr(features._Feature, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")

dispatcher_signature = inspect.signature(info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

feature_signature = inspect.signature(feature_method)
feature_params = list(feature_signature.parameters.values())[1:]

# Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined,
# the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
# concrete dispatcher annotations.
feature_annotations = get_type_hints(feature_method)
for param in feature_params:
param._annotation = feature_annotations[param.name]

assert dispatcher_params == feature_params


@pytest.mark.parametrize(
("alias", "target"),
Expand Down