
Commit 31bee5f

port all remaining dispatchers to the new mechanism
1 parent 7203453 · commit 31bee5f

17 files changed: +644, -1093 lines


test/datasets_utils.py
Lines changed: 4 additions & 2 deletions

```diff
@@ -567,7 +567,7 @@ def test_transforms(self, config):

     @test_all_configs
     def test_transforms_v2_wrapper(self, config):
-        from torchvision.datapoints._datapoint import Datapoint
+        from torchvision import datapoints
        from torchvision.datasets import wrap_dataset_for_transforms_v2

        try:
@@ -588,7 +588,9 @@ def test_transforms_v2_wrapper(self, config):
             assert len(wrapped_dataset) == info["num_examples"]

             wrapped_sample = wrapped_dataset[0]
-            assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
+            assert tree_any(
+                lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
+            )
        except TypeError as error:
            msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
            if str(error).startswith(msg):
```
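For reference, a hedged sketch of what the updated assertion exercises: wrapping a dataset for v2 transforms should produce samples whose leaves are either `Datapoint` subclasses or plain PIL images. This assumes torchvision's beta transforms API and a dataset that has a v2 wrapper; CIFAR10 is used purely for illustration:

```python
import PIL.Image

from torchvision import datapoints
from torchvision.datasets import CIFAR10, wrap_dataset_for_transforms_v2

dataset = wrap_dataset_for_transforms_v2(CIFAR10(root="data", download=True))
image, label = dataset[0]

# Classification samples keep plain PIL images; detection-style datasets would
# instead yield Datapoint subclasses such as BoundingBoxes in the target.
assert isinstance(image, (datapoints.Datapoint, PIL.Image.Image))
```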

test/test_transforms_v2_functional.py
Lines changed: 6 additions & 43 deletions

```diff
@@ -3,8 +3,6 @@
 import os
 import re

-from typing import get_type_hints
-
 import numpy as np
 import PIL.Image
 import pytest
@@ -417,22 +415,6 @@ def test_pil_output_type(self, info, args_kwargs):

         assert isinstance(output, PIL.Image.Image)

-    @make_info_args_kwargs_parametrization(
-        DISPATCHER_INFOS,
-        args_kwargs_fn=lambda info: info.sample_inputs(),
-    )
-    def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
-        (datapoint, *other_args), kwargs = args_kwargs.load()
-
-        method_name = info.id
-        method = getattr(datapoint, method_name)
-        datapoint_type = type(datapoint)
-        spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
-
-        info.dispatcher(datapoint, *other_args, **kwargs)
-
-        spy.assert_called_once()
-
     @make_info_args_kwargs_parametrization(
         DISPATCHER_INFOS,
         args_kwargs_fn=lambda info: info.sample_inputs(),
@@ -462,9 +444,12 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
         kernel_params = list(kernel_signature.parameters.values())[1:]

         # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
-        # explicit passed to the kernel.
-        datapoint_type_metadata = datapoint_type.__annotations__.keys()
-        kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
+        # explicitly passed to the kernel.
+        input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
+        explicit_metadata = {
+            datapoints.BoundingBoxes: {"format", "canvas_size"},
+        }
+        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

         dispatcher_params = iter(dispatcher_params)
         for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -481,28 +466,6 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi

         assert dispatcher_param == kernel_param

-    @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
-    def test_dispatcher_datapoint_signatures_consistency(self, info):
-        try:
-            datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
-        except AttributeError:
-            pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")
-
-        dispatcher_signature = inspect.signature(info.dispatcher)
-        dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
-
-        datapoint_signature = inspect.signature(datapoint_method)
-        datapoint_params = list(datapoint_signature.parameters.values())[1:]
-
-        # Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
-        # defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
-        # natively concrete dispatcher annotations.
-        datapoint_annotations = get_type_hints(datapoint_method)
-        for param in datapoint_params:
-            param._annotation = datapoint_annotations[param.name]
-
-        assert dispatcher_params == datapoint_params
-
     @pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
     def test_unkown_type(self, info):
         unkown_input = object()
```
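The new `input_type` lookup inverts the `dispatcher_info.kernels` mapping to recover which datapoint type a kernel was registered for. A small self-contained illustration; the names below are stand-ins, not torchvision APIs:

```python
def resize_bounding_boxes():
    ...


def resize_mask():
    ...


# Stand-in for dispatcher_info.kernels: {datapoint type -> kernel}
kernels = {"BoundingBoxes": resize_bounding_boxes, "Mask": resize_mask}

# Inverting the dict maps each kernel back to its datapoint type; .get()
# returns None for kernels that were never registered.
input_type = {v: k for k, v in kernels.items()}.get(resize_bounding_boxes)
assert input_type == "BoundingBoxes"
```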

test/test_transforms_v2_refactored.py
Lines changed: 33 additions & 62 deletions

```diff
@@ -3,7 +3,6 @@
 import inspect
 import math
 import re
-from typing import get_type_hints
 from unittest import mock

 import numpy as np
@@ -178,28 +177,28 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs):
     """Checks if the dispatcher correctly dispatches the input to the corresponding kernel and that the input type is
     preserved in doing so. For bounding boxes also checks that the format is preserved.
     """
-    if isinstance(input, datapoints._datapoint.Datapoint):
-        if dispatcher in {F.resize, F.adjust_brightness}:
+    input_type = type(input)
+
+    if isinstance(input, datapoints.Datapoint):
+        wrapped_kernel = _KERNEL_REGISTRY[dispatcher][input_type]
+
+        # In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
+        # proper kernel was wrapped
+        if hasattr(wrapped_kernel, "__wrapped__"):
+            assert wrapped_kernel.__wrapped__ is kernel
+
+        spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
+        with mock.patch.dict(_KERNEL_REGISTRY[dispatcher], values={input_type: spy}):
             output = dispatcher(input, *args, **kwargs)
-        else:
-            # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly,
-            # but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel.
-            spy = mock.MagicMock(wraps=kernel, name=kernel.__name__)
-            with mock.patch.object(F, kernel.__name__, spy):
-                # Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class.
-                # Since that is not the case here, we need to prefix f"_{cls.__name__}"
-                # See https://docs.python.org/3/tutorial/classes.html#private-variables for details
-                with mock.patch.object(datapoints._datapoint.Datapoint, "_Datapoint__F", new=F):
-                    output = dispatcher(input, *args, **kwargs)

-            spy.assert_called_once()
+        spy.assert_called_once()
     else:
         with mock.patch(f"{dispatcher.__module__}.{kernel.__name__}", wraps=kernel) as spy:
             output = dispatcher(input, *args, **kwargs)

         spy.assert_called_once()

-    assert isinstance(output, type(input))
+    assert isinstance(output, input_type)

     if isinstance(input, datapoints.BoundingBoxes):
         assert output.format == input.format
```
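For context, the "new mechanism" this test now exercises is a plain dict-of-dicts registry keyed first by dispatcher, then by input type. The sketch below is hedged: `_KERNEL_REGISTRY` and the decorator name `_register_kernel_internal` appear in this commit, but the wrapper body and the `_dispatch` helper are illustrative guesses, not torchvision's actual implementation.

```python
import functools

import torch

# {dispatcher: {datapoint type: wrapped kernel}}, matching the lookup
# `_KERNEL_REGISTRY[dispatcher][input_type]` in the test above.
_KERNEL_REGISTRY = {}


def _register_kernel_internal(dispatcher, datapoint_cls):
    def decorator(kernel):
        @functools.wraps(kernel)  # exposes the kernel via __wrapped__, which the test checks
        def wrapper(inpt, *args, **kwargs):
            # Kernels operate on plain tensors: unwrap, compute, re-wrap to the
            # input type. (Real datapoints also carry metadata; this sketch ignores that.)
            output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
            return output.as_subclass(type(inpt))

        _KERNEL_REGISTRY.setdefault(dispatcher, {})[datapoint_cls] = wrapper
        return kernel

    return decorator


def _dispatch(dispatcher, inpt, *args, **kwargs):
    # Exact-type lookup; unregistered types raise, as the unknown-input checks expect.
    kernel = _KERNEL_REGISTRY.get(dispatcher, {}).get(type(inpt))
    if kernel is None:
        raise TypeError(f"{dispatcher.__name__} does not support inputs of type {type(inpt)}")
    return kernel(inpt, *args, **kwargs)
```

With a registry like this, the test can spy on dispatch non-invasively: it wraps the registered kernel in a `MagicMock` and swaps it in via `mock.patch.dict`, which is far simpler than the name-mangled `Datapoint.__F` patching the old mechanism required.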
```diff
@@ -214,34 +213,32 @@ def check_dispatcher(
     check_dispatch=True,
     **kwargs,
 ):
+    unknown_input = object()
     with mock.patch("torch._C._log_api_usage_once", wraps=torch._C._log_api_usage_once) as spy:
-        dispatcher(input, *args, **kwargs)
+        with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
+            dispatcher(unknown_input, *args, **kwargs)

         spy.assert_any_call(f"{dispatcher.__module__}.{dispatcher.__name__}")

-    unknown_input = object()
-    with pytest.raises(TypeError, match=re.escape(str(type(unknown_input)))):
-        dispatcher(unknown_input, *args, **kwargs)
-
     if check_scripted_smoke:
         _check_dispatcher_scripted_smoke(dispatcher, input, *args, **kwargs)

     if check_dispatch:
         _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs)


-def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
+def check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
     """Checks if the signature of the dispatcher matches the kernel signature."""
-    dispatcher_signature = inspect.signature(dispatcher)
-    dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
-
-    kernel_signature = inspect.signature(kernel)
-    kernel_params = list(kernel_signature.parameters.values())[1:]
+    dispatcher_params = list(inspect.signature(dispatcher).parameters.values())[1:]
+    kernel_params = list(inspect.signature(kernel).parameters.values())[1:]

-    if issubclass(input_type, datapoints._datapoint.Datapoint):
+    if issubclass(input_type, datapoints.Datapoint):
         # We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
         # explicitly passed to the kernel.
-        kernel_params = [param for param in kernel_params if param.name not in input_type.__annotations__.keys()]
+        explicit_metadata = {
+            datapoints.BoundingBoxes: {"format", "canvas_size"},
+        }
+        kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

     dispatcher_params = iter(dispatcher_params)
     for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -264,32 +261,6 @@ def _check_dispatcher_kernel_signature_match(dispatcher, *, kernel, input_type):
     assert dispatcher_param == kernel_param


-def _check_dispatcher_datapoint_signature_match(dispatcher):
-    """Checks if the signature of the dispatcher matches the corresponding method signature on the Datapoint class."""
-    if dispatcher in {F.resize, F.adjust_brightness}:
-        return
-    dispatcher_signature = inspect.signature(dispatcher)
-    dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
-
-    datapoint_method = getattr(datapoints._datapoint.Datapoint, dispatcher.__name__)
-    datapoint_signature = inspect.signature(datapoint_method)
-    datapoint_params = list(datapoint_signature.parameters.values())[1:]
-
-    # Some annotations in the `datapoints._datapoint` module
-    # are stored as strings. The block below makes them concrete again (non-strings), so they can be compared to the
-    # natively concrete dispatcher annotations.
-    datapoint_annotations = get_type_hints(datapoint_method)
-    for param in datapoint_params:
-        param._annotation = datapoint_annotations[param.name]
-
-    assert dispatcher_params == datapoint_params
-
-
-def check_dispatcher_signatures_match(dispatcher, *, kernel, input_type):
-    _check_dispatcher_kernel_signature_match(dispatcher, kernel=kernel, input_type=input_type)
-    _check_dispatcher_datapoint_signature_match(dispatcher)
-
-
 def _check_transform_v1_compatibility(transform, input):
     """If the transform defines the ``_v1_transform_cls`` attribute, checks if the transform has a public, static
     ``get_params`` method, is scriptable, and the scripted version can be called without error."""
```
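The `explicit_metadata` filter exists because a bounding box kernel takes `format` and `canvas_size` as explicit parameters, while the dispatcher reads them off the datapoint. A self-contained illustration of the filtering; the `resize_bounding_boxes` signature below is a made-up stand-in:

```python
import inspect


# Hypothetical kernel: for bounding boxes, "format" and "canvas_size" must be
# passed explicitly, whereas a dispatcher would read them off the datapoint.
def resize_bounding_boxes(inpt, format, canvas_size, size):
    ...


explicit_metadata = {"format", "canvas_size"}
kernel_params = [
    param
    for param in list(inspect.signature(resize_bounding_boxes).parameters.values())[1:]
    if param.name not in explicit_metadata
]
# After filtering, only the parameters shared with the dispatcher remain.
assert [param.name for param in kernel_params] == ["size"]
```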
```diff
@@ -461,7 +432,7 @@ def test_exhaustive_kernel_registration(dispatcher, registered_datapoint_clss):
                 *[f"- {name}" for name in names],
                 "",
                 f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).",
-                f"If not, register explicit no-ops with @_register_explicit_noops({', '.join(names)})",
+                f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})",
             ]
         )
     )
```
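A hedged sketch of what explicit no-op registration could look like under the registry model sketched earlier. The name `_register_explicit_noop` comes from the corrected error message above; its body here is guesswork, not torchvision's code, and it assumes the `_KERNEL_REGISTRY` sketch from before:

```python
def _register_explicit_noop(*datapoint_clss):
    """Mark datapoint types that an operation intentionally leaves untouched."""

    def decorator(dispatcher):
        for cls in datapoint_clss:
            # e.g. a brightness adjustment on a segmentation Mask: return the
            # input unchanged instead of raising an unsupported-type error.
            _KERNEL_REGISTRY.setdefault(dispatcher, {})[cls] = lambda inpt, *args, **kwargs: inpt
        return dispatcher

    return decorator
```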
```diff
@@ -602,7 +573,7 @@ def test_dispatcher(self, size, kernel, make_input):
         ],
     )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.resize, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.resize, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize("size", OUTPUT_SIZES)
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -800,7 +771,7 @@ def test_noop(self, size, make_input):

         # This identity check is not a requirement. It is here to avoid breaking the behavior by accident. If there
         # is a good reason to break this, feel free to downgrade to an equality check.
-        if isinstance(input, datapoints._datapoint.Datapoint):
+        if isinstance(input, datapoints.Datapoint):
             # We can't test identity directly, since that checks for the identity of the Python object. Since all
             # datapoints unwrap before a kernel and wrap again afterwards, the Python object changes. Thus, we check
             # that the underlying storage is the same
@@ -884,7 +855,7 @@ def test_dispatcher(self, kernel, make_input):
         ],
     )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.horizontal_flip, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -1067,7 +1038,7 @@ def test_dispatcher(self, kernel, make_input):
         ],
     )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.affine, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.affine, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -1363,7 +1334,7 @@ def test_dispatcher(self, kernel, make_input):
         ],
     )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.vertical_flip, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -1520,7 +1491,7 @@ def test_dispatcher(self, kernel, make_input):
         ],
     )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.rotate, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.rotate, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize(
         "make_input",
@@ -1971,7 +1942,7 @@ def test_dispatcher(self, kernel, make_input):
         ],
     )
     def test_dispatcher_signature(self, kernel, input_type):
-        check_dispatcher_signatures_match(F.adjust_brightness, kernel=kernel, input_type=input_type)
+        check_dispatcher_kernel_signature_match(F.adjust_brightness, kernel=kernel, input_type=input_type)

     @pytest.mark.parametrize("brightness_factor", _CORRECTNESS_BRIGHTNESS_FACTORS)
     def test_image_correctness(self, brightness_factor):
```

test/transforms_v2_dispatcher_infos.py
Lines changed: 9 additions & 8 deletions

```diff
@@ -69,14 +69,15 @@ def sample_inputs(self, *datapoint_types, filter_metadata=True):
         import itertools

         for args_kwargs in sample_inputs:
-            for name in itertools.chain(
-                datapoint_type.__annotations__.keys(),
-                # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
-                # per-dispatcher level. However, so far there is no option for that.
-                (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
-            ):
-                if name in args_kwargs.kwargs:
-                    del args_kwargs.kwargs[name]
+            if hasattr(datapoint_type, "__annotations__"):
+                for name in itertools.chain(
+                    datapoint_type.__annotations__.keys(),
+                    # FIXME: this seems ok for conversion dispatchers, but we should probably handle this on a
+                    # per-dispatcher level. However, so far there is no option for that.
+                    (f"old_{name}" for name in datapoint_type.__annotations__.keys()),
+                ):
+                    if name in args_kwargs.kwargs:
+                        del args_kwargs.kwargs[name]

             yield args_kwargs
```
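The new `hasattr` guard matters because not every datapoint class declares class-level annotations, and depending on the Python version, accessing `__annotations__` on such a class may raise `AttributeError` rather than return an empty dict. A toy illustration; the classes are stand-ins:

```python
class BoundingBoxes:
    format: str
    canvas_size: tuple


class Mask:  # declares no class-level annotations
    pass


for cls in (BoundingBoxes, Mask):
    # Guarded access: metadata filtering simply becomes a no-op for classes
    # without annotations instead of crashing.
    names = sorted(cls.__annotations__.keys()) if hasattr(cls, "__annotations__") else []
    print(cls.__name__, names)  # BoundingBoxes ['canvas_size', 'format'], then Mask []
```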

torchvision/datapoints/__init__.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -1,7 +1,7 @@
 from torchvision import _BETA_TRANSFORMS_WARNING, _WARN_ABOUT_BETA_TRANSFORMS

 from ._bounding_box import BoundingBoxes, BoundingBoxFormat
-from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT
+from ._datapoint import _FillType, _FillTypeJIT, _InputType, _InputTypeJIT, Datapoint
 from ._image import _ImageType, _ImageTypeJIT, _TensorImageType, _TensorImageTypeJIT, Image
 from ._mask import Mask
 from ._video import _TensorVideoType, _TensorVideoTypeJIT, _VideoType, _VideoTypeJIT, Video
```
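This re-export is what lets the tests above write `datapoints.Datapoint` instead of reaching into the private `_datapoint` module. For example:

```python
from torchvision import datapoints

# Public isinstance/issubclass checks, no private-module import required.
assert issubclass(datapoints.Image, datapoints.Datapoint)
assert issubclass(datapoints.BoundingBoxes, datapoints.Datapoint)
```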
