Skip to content

Commit e3f7baa

Browse files
pmeierdatumbox
andauthored
add test for dispatcher kernel signature consistency (#6904)
* add test for dispatcher kernel signature consistency * add dispatcher feature signature consistency test * fix error message Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 4f3a000 commit e3f7baa

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import inspect
12
import math
23
import os
34
import re
45

6+
from typing import get_type_hints
7+
58
import numpy as np
69
import PIL.Image
710
import pytest
@@ -314,6 +317,63 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on):
314317

315318
spy.assert_called_once()
316319

320+
@pytest.mark.parametrize(
321+
("dispatcher_info", "feature_type", "kernel_info"),
322+
[
323+
pytest.param(dispatcher_info, feature_type, kernel_info, id=f"{dispatcher_info.id}-{feature_type.__name__}")
324+
for dispatcher_info in DISPATCHER_INFOS
325+
for feature_type, kernel_info in dispatcher_info.kernel_infos.items()
326+
],
327+
)
328+
def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, feature_type, kernel_info):
329+
dispatcher_signature = inspect.signature(dispatcher_info.dispatcher)
330+
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
331+
332+
kernel_signature = inspect.signature(kernel_info.kernel)
333+
kernel_params = list(kernel_signature.parameters.values())[1:]
334+
335+
# We filter out metadata that is implicitly passed to the dispatcher through the input feature, but has to be
336+
# explicit passed to the kernel.
337+
feature_type_metadata = feature_type.__annotations__.keys()
338+
kernel_params = [param for param in kernel_params if param.name not in feature_type_metadata]
339+
340+
dispatcher_params = iter(dispatcher_params)
341+
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
342+
try:
343+
# In general, the dispatcher parameters are a superset of the kernel parameters. Thus, we filter out
344+
# dispatcher parameters that have no kernel equivalent while keeping the order intact.
345+
while dispatcher_param.name != kernel_param.name:
346+
dispatcher_param = next(dispatcher_params)
347+
except StopIteration:
348+
raise AssertionError(
349+
f"Parameter `{kernel_param.name}` of kernel `{kernel_info.id}` "
350+
f"has no corresponding parameter on the dispatcher `{dispatcher_info.id}`."
351+
) from None
352+
353+
assert dispatcher_param == kernel_param
354+
355+
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
356+
def test_dispatcher_feature_signatures_consistency(self, info):
357+
try:
358+
feature_method = getattr(features._Feature, info.id)
359+
except AttributeError:
360+
pytest.skip("Dispatcher doesn't support arbitrary feature dispatch.")
361+
362+
dispatcher_signature = inspect.signature(info.dispatcher)
363+
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
364+
365+
feature_signature = inspect.signature(feature_method)
366+
feature_params = list(feature_signature.parameters.values())[1:]
367+
368+
# Because we use `from __future__ import annotations` inside the module where `features._Feature` is defined,
369+
# the annotations are stored as strings. This makes them concrete again, so they can be compared to the natively
370+
# concrete dispatcher annotations.
371+
feature_annotations = get_type_hints(feature_method)
372+
for param in feature_params:
373+
param._annotation = feature_annotations[param.name]
374+
375+
assert dispatcher_params == feature_params
376+
317377

318378
@pytest.mark.parametrize(
319379
("alias", "target"),

0 commit comments

Comments
 (0)