Merged

Changes from all commits (26 commits)
d9e1379  [PoC] refactor Datapoint dispatch mechanism (pmeier, Jul 19, 2023)
36b9d36  fix test (pmeier, Jul 19, 2023)
f36c64c  Merge branch 'main' into kernel-registration (pmeier, Jul 26, 2023)
bbaa35c  add dispatch to adjust_brightness (pmeier, Jul 27, 2023)
ca4ad32  enforce no register overwrite (pmeier, Jul 27, 2023)
d23a80e  [PoC] make wrapping interal kernel more convenient (pmeier, Jul 27, 2023)
bf47188  [PoC] enforce explicit no-ops (pmeier, Jul 27, 2023)
74d5054  fix adjust_brightness tests and remove methods (pmeier, Jul 27, 2023)
e88be5e  Merge branch 'main' into kernel-registration (pmeier, Jul 27, 2023)
f178373  address minor comments (pmeier, Jul 27, 2023)
65e80d0  make no-op registration a decorator (pmeier, Jul 28, 2023)
9614477  Merge branch 'main' (pmeier, Aug 1, 2023)
6ac08e4  explicit metadata (pmeier, Aug 1, 2023)
cac079b  implement dispatchers for erase five/ten_crop and temporal_subsample (pmeier, Aug 1, 2023)
c7256b4  make shape getters proper dispatchers (pmeier, Aug 1, 2023)
bf78cd6  fix (pmeier, Aug 1, 2023)
f86f89b  port normalize and to_dtype (pmeier, Aug 2, 2023)
d90daf6  address comments (pmeier, Aug 2, 2023)
09eec9a  address comments and cleanup (pmeier, Aug 2, 2023)
3730811  more cleanup (pmeier, Aug 2, 2023)
7203453  Merge branch 'main' into kernel-registration (pmeier, Aug 2, 2023)
31bee5f  port all remaining dispatchers to the new mechanism (pmeier, Jul 28, 2023)
a924013  put back legacy test_dispatch_datapoint (pmeier, Aug 2, 2023)
b3c2c88  minor test fixes (pmeier, Aug 2, 2023)
a1f5ea4  Update torchvision/transforms/v2/functional/_utils.py (pmeier, Aug 2, 2023)
d29d95b  reinstante antialias tests (pmeier, Aug 2, 2023)
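The commit log above describes replacing Datapoint-method dispatch with a kernel registry keyed by dispatcher and datapoint type, where registration refuses to overwrite an existing entry and no-ops must be registered explicitly. A minimal sketch of how such a registry could work, with made-up helper names (register_kernel, dispatch); the real implementation lives in torchvision/transforms/v2/functional/_utils.py and differs in detail:

from collections import defaultdict

_KERNEL_REGISTRY = defaultdict(dict)  # dispatcher -> {datapoint type: kernel}


def register_kernel(dispatcher, datapoint_cls):
    """Decorator that registers a kernel as the implementation of `dispatcher` for `datapoint_cls`."""

    def decorator(kernel):
        registry = _KERNEL_REGISTRY[dispatcher]
        if datapoint_cls in registry:
            # "enforce no register overwrite": registering a second kernel is a bug
            raise ValueError(f"{dispatcher.__name__} already has a kernel for {datapoint_cls.__name__}")
        registry[datapoint_cls] = kernel
        return kernel

    return decorator


def dispatch(dispatcher, inpt, *args, **kwargs):
    # Look up the kernel registered for the input's type and call it.
    kernel = _KERNEL_REGISTRY[dispatcher].get(type(inpt))
    if kernel is None:
        raise TypeError(f"{dispatcher.__name__} has no kernel registered for {type(inpt).__name__}")
    return kernel(inpt, *args, **kwargs)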
4 changes: 4 additions & 0 deletions test/common_utils.py
@@ -829,6 +829,10 @@ def make_video(size=DEFAULT_SIZE, *, num_frames=3, batch_dims=(), **kwargs):
return datapoints.Video(make_image(size, batch_dims=(*batch_dims, num_frames), **kwargs))


def make_video_tensor(*args, **kwargs):
return make_video(*args, **kwargs).as_subclass(torch.Tensor)


def make_video_loader(
size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
*,
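A short usage sketch of the helper added above; it strips the datapoints.Video subclass so tests can exercise the plain-tensor path of the dispatchers:

import torch
from common_utils import make_video_tensor

video = make_video_tensor()  # relies on make_video's defaults
assert type(video) is torch.Tensor  # plain tensor, not a datapoints.Video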
6 changes: 4 additions & 2 deletions test/datasets_utils.py
@@ -567,7 +567,7 @@ def test_transforms(self, config):

@test_all_configs
def test_transforms_v2_wrapper(self, config):
from torchvision.datapoints._datapoint import Datapoint
from torchvision import datapoints
from torchvision.datasets import wrap_dataset_for_transforms_v2

try:
@@ -588,7 +588,9 @@ def test_transforms_v2_wrapper(self, config):
assert len(wrapped_dataset) == info["num_examples"]

wrapped_sample = wrapped_dataset[0]
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
assert tree_any(
lambda item: isinstance(item, (datapoints.Datapoint, PIL.Image.Image)), wrapped_sample
)
except TypeError as error:
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
if str(error).startswith(msg):
10 changes: 5 additions & 5 deletions test/test_transforms_v2.py
@@ -1344,12 +1344,12 @@ def test_antialias_warning():
transforms.RandomResize(10, 20)(tensor_img)

with pytest.warns(UserWarning, match=match):
datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20))
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20))

with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resize((20, 20))
F.resize(datapoints.Video(tensor_video), (20, 20))
with pytest.warns(UserWarning, match=match):
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20))
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20))

with warnings.catch_warnings():
warnings.simplefilter("error")
@@ -1363,8 +1363,8 @@
transforms.RandomShortestSize((20, 20), antialias=True)(tensor_img)
transforms.RandomResize(10, 20, antialias=True)(tensor_img)

datapoints.Image(tensor_img).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
datapoints.Video(tensor_video).resized_crop(0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(datapoints.Image(tensor_img), 0, 0, 10, 10, (20, 20), antialias=True)
F.resized_crop(datapoints.Video(tensor_video), 0, 0, 10, 10, (20, 20), antialias=True)


@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
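These changes replace the removed Datapoint methods with calls to the functional dispatchers. A usage sketch of the new call style (input shape, dtype, and crop parameters chosen arbitrarily for illustration):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

img = datapoints.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))
out = F.resized_crop(img, top=0, left=0, height=10, width=10, size=(20, 20), antialias=True)
assert isinstance(out, datapoints.Image)  # the dispatcher re-wraps the output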
52 changes: 19 additions & 33 deletions test/test_transforms_v2_functional.py
@@ -2,13 +2,11 @@
import math
import os
import re

from typing import get_type_hints
from unittest import mock

import numpy as np
import PIL.Image
import pytest

import torch

from common_utils import (
@@ -27,6 +25,7 @@
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding
from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes
from torchvision.transforms.v2.functional._utils import _KERNEL_REGISTRY
from torchvision.transforms.v2.utils import is_simple_tensor
from transforms_v2_dispatcher_infos import DISPATCHER_INFOS
from transforms_v2_kernel_infos import KERNEL_INFOS
@@ -424,12 +423,18 @@ def test_pil_output_type(self, info, args_kwargs):
def test_dispatch_datapoint(self, info, args_kwargs, spy_on):
(datapoint, *other_args), kwargs = args_kwargs.load()

method_name = info.id
method = getattr(datapoint, method_name)
datapoint_type = type(datapoint)
spy = spy_on(method, module=datapoint_type.__module__, name=f"{datapoint_type.__name__}.{method_name}")
input_type = type(datapoint)

wrapped_kernel = _KERNEL_REGISTRY[info.dispatcher][input_type]

info.dispatcher(datapoint, *other_args, **kwargs)
# In case the wrapper was decorated with @functools.wraps, we can make the check more strict and test if the
# proper kernel was wrapped
if hasattr(wrapped_kernel, "__wrapped__"):
assert wrapped_kernel.__wrapped__ is info.kernels[input_type]

spy = mock.MagicMock(wraps=wrapped_kernel, name=wrapped_kernel.__name__)
with mock.patch.dict(_KERNEL_REGISTRY[info.dispatcher], values={input_type: spy}):
info.dispatcher(datapoint, *other_args, **kwargs)

spy.assert_called_once()
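The rewritten test above pulls the registered kernel out of _KERNEL_REGISTRY, checks __wrapped__ when present, and then swaps the entry for a MagicMock spy via mock.patch.dict. The __wrapped__ check works because a kernel wrapper decorated with functools.wraps keeps a reference to the original function; an illustrative sketch with made-up names (_wrap_kernel, resize_image), not the actual torchvision wrapper:

import functools


def _wrap_kernel(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # e.g. unwrap the datapoint or fill in metadata before calling the kernel
        return kernel(inpt, *args, **kwargs)

    return wrapper


def resize_image(image, size):  # stand-in for a real low-level kernel
    return image


wrapped = _wrap_kernel(resize_image)
assert wrapped.__wrapped__ is resize_image  # the stricter check performed by the test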

@@ -462,9 +467,12 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi
kernel_params = list(kernel_signature.parameters.values())[1:]

# We filter out metadata that is implicitly passed to the dispatcher through the input datapoint, but has to be
# explicit passed to the kernel.
datapoint_type_metadata = datapoint_type.__annotations__.keys()
kernel_params = [param for param in kernel_params if param.name not in datapoint_type_metadata]
# explicitly passed to the kernel.
input_type = {v: k for k, v in dispatcher_info.kernels.items()}.get(kernel_info.kernel)
explicit_metadata = {
datapoints.BoundingBoxes: {"format", "canvas_size"},
}
kernel_params = [param for param in kernel_params if param.name not in explicit_metadata.get(input_type, set())]

dispatcher_params = iter(dispatcher_params)
for dispatcher_param, kernel_param in zip(dispatcher_params, kernel_params):
Expand All @@ -481,28 +489,6 @@ def test_dispatcher_kernel_signatures_consistency(self, dispatcher_info, datapoi

assert dispatcher_param == kernel_param

@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_dispatcher_datapoint_signatures_consistency(self, info):
try:
datapoint_method = getattr(datapoints._datapoint.Datapoint, info.id)
except AttributeError:
pytest.skip("Dispatcher doesn't support arbitrary datapoint dispatch.")

dispatcher_signature = inspect.signature(info.dispatcher)
dispatcher_params = list(dispatcher_signature.parameters.values())[1:]

datapoint_signature = inspect.signature(datapoint_method)
datapoint_params = list(datapoint_signature.parameters.values())[1:]

# Because we use `from __future__ import annotations` inside the module where `datapoints._datapoint` is
# defined, the annotations are stored as strings. This makes them concrete again, so they can be compared to the
# natively concrete dispatcher annotations.
datapoint_annotations = get_type_hints(datapoint_method)
for param in datapoint_params:
param._annotation = datapoint_annotations[param.name]

assert dispatcher_params == datapoint_params

@pytest.mark.parametrize("info", DISPATCHER_INFOS, ids=lambda info: info.id)
def test_unkown_type(self, info):
unkown_input = object()
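The explicit_metadata table added to test_dispatcher_kernel_signatures_consistency above exists because bounding-box kernels receive format and canvas_size as explicit parameters, while the dispatcher reads them off the BoundingBoxes input. A hypothetical sketch of that split (signatures are illustrative; the real torchvision kernels and dispatchers differ in detail):

import torch
from torchvision import datapoints


def horizontal_flip_bounding_boxes(boxes, *, format, canvas_size):
    # kernel: operates on a plain tensor, metadata must be passed explicitly
    return boxes


def horizontal_flip(inpt):
    # dispatcher: metadata comes from the datapoint, so it is not part of the signature
    if isinstance(inpt, datapoints.BoundingBoxes):
        flipped = horizontal_flip_bounding_boxes(
            inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size
        )
        return datapoints.BoundingBoxes(flipped, format=inpt.format, canvas_size=inpt.canvas_size)
    raise TypeError(f"horizontal_flip has no kernel for {type(inpt).__name__}")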