Skip to content

Commit 8bd2151

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] refactor Datapoint dispatch mechanism (#7747)
Summary: Co-authored-by: Nicolas Hug <[email protected]> Reviewed By: matteobettini Differential Revision: D48642281 fbshipit-source-id: 33a1dcba4bbc254a26ae091452a61609bb80f663
1 parent db56d55 commit 8bd2151

24 files changed

+1215
-1428
lines changed

test/common_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,10 @@ def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
829829
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))
830830

831831

832+
def make_video_tensor(*args, **kwargs):
833+
return make_video(*args, **kwargs).as_subclass(torch.Tensor)
834+
835+
832836
def make_video_loader(
833837
size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
834838
*,

test/datasets_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ def test_transforms(self, config):
567567

568568
@test_all_configs
569569
def test_transforms_v2_wrapper(self, config):
570-
from torchvision.datapoints._datapoint import Datapoint
570+
from torchvision import datapoints
571571
from torchvision.datasets import wrap_dataset_for_transforms_v2
572572

573573
try:
@@ -588,7 +588,9 @@ def test_transforms_v2_wrapper(self, config):
588588
assert len(wrapped_dataset) == info["num_examples"]
589589

590590
wrapped_sample = wrapped_dataset[0]
591-
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
591+
assert tree_any(
592+
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
593+
)
592594
except TypeError as error:
593595
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
594596
if str(error).startswith(msg):

test/test_transforms_v2.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,12 +1344,12 @@ def test_antialias_warning():
13441344
transforms.RandomResize(10, 20)(tensor_img)
13451345

13461346
with pytest.warns(UserWarning, match=match):
1347-
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))
1347+
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))
13481348

13491349
with pytest.warns(UserWarning, match=match):
1350-
datapoints.Video(tensor_video).resize((20, 20))
1350+
F.resize(datapoints.Video(tensor_video), (20, 20))
13511351
with pytest.warns(UserWarning, match=match):
1352-
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))
1352+
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))
13531353

13541354
with warnings.catch_warnings():
13551355
warnings.simplefilter("error")
@@ -1363,8 +1363,8 @@ def test_antialias_warning():
13631363
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
13641364
transforms.RandomResize(10, 20, antialias=True)(tensor_img)
13651365

1366-
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
1367-
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
1366+
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
1367+
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)
13681368

13691369

13701370
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))

test/test_transforms_v2_functional.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
import math
33
import os
44
import re
5-
6-
from typing import get_type_hints
5+
from unittest import mock
76

87
import numpy as np
98
import PIL.Image
109
import pytest
11-
1210
import torch
1311

1412
from common_utils import (
@@ -27,6 +25,7 @@
2725
from torchvision.transforms.v2 import functional as F
2826
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
2927
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
28+
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
3029
from torchvision.transforms.v2.utils import is_simple_tensor
3130
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
3231
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -424,12 +423,18 @@ def test_pil_output_type(self, info, args_kwargs):
424423
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
425424
(datapoint, *other_args), kwargs = args_kwargs.load()
426425

427-
method_name = info.id
428-
method = getattr(datapoint, method_name)
429-
datapoint_type = type(datapoint)
430-
spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
426+
input_type = type(datapoint)
427+
428+
wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]
431429

432-
info.dispatcher(datapoint, *other_args, **kwargs)
430+
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
431+
# proper kernel was wrapped
432+
if hasattr(wrapped_kernel, "__wrapped__"):
433+
assert wrapped_kernel.__wrapped__ is info.kernels[input_type]
434+
435+
spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
436+
with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
437+
info.dispatcher(datapoint, *other_args, **kwargs)
433438

434439
spy.assert_called_once()
435440

@@ -462,9 +467,12 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
462467
kernel_params = list(kernel_signature.parameters.values())[1:]
463468

464469
# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
465-
# explicit passed to the kernel.
466-
datapoint_type_metadata = datapoint_type.__annotations__.keys()
467-
kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
470+
# explicitly passed to the kernel.
471+
input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
472+
explicit_metadata = {
473+
datapoints.BoundingBoxes: {"format", "canvas_size"},
474+
}
475+
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]
468476

469477
dispatcher_params = iter(dispatcher_params)
470478
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
@@ -481,28 +489,6 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
481489

482490
assert dispatcher_param == kernel_param
483491

484-
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
485-
def test_dispatcher_datapoint_signatures_consistency(self, info):
486-
try:
487-
datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
488-
except AttributeError:
489-
pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")
490-
491-
dispatcher_signature = inspect.signature(info.dispatcher)
492-
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]
493-
494-
datapoint_signature = inspect.signature(datapoint_method)
495-
datapoint_params = list(datapoint_signature.parameters.values())[1:]
496-
497-
# Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
498-
# defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
499-
# natively concrete dispatcher annotations.
500-
datapoint_annotations = get_type_hints(datapoint_method)
501-
for param in datapoint_params:
502-
param._annotation = datapoint_annotations[param.name]
503-
504-
assert dispatcher_params == datapoint_params
505-
506492
@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
507493
def test_unkown_type(self, info):
508494
unkown_input = object()

0 commit comments

Comments
 (0)