3 changes: 2 additions & 1 deletion test/prototype_common_utils.py
@@ -68,7 +68,8 @@ def __init__(

    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else input for input in [actual, expected]
+            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
+            for input in [actual, expected]
        ]
        return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)

120 changes: 114 additions & 6 deletions test/test_prototype_transforms_consistency.py
@@ -1,5 +1,6 @@
import enum
import inspect
+import random
from importlib.machinery import SourceFileLoader
from pathlib import Path

@@ -16,6 +17,7 @@
    make_image,
    make_images,
    make_label,
+    make_segmentation_mask,
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
@@ -852,10 +854,12 @@ def test_aa(self, inpt, interpolation):
        assert_equal(expected_output, output)


-# Import reference detection transforms here for consistency checks
-# torchvision/references/detection/transforms.py
-ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
-det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+def import_transforms_from_references(reference):
+    ref_det_filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py"
+    return SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()
+
+
+det_transforms = import_transforms_from_references("detection")


class TestRefDetTransforms:
@@ -873,7 +877,7 @@ def make_datapoints(self, with_mask=True):

        yield (pil_image, target)

-        tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
+        tensor_image = torch.Tensor(make_image(size=size, color_space=features.ColorSpace.RGB))
Comment (Contributor Author): I missed this before. Let's use the utilities everywhere.

        target = {
            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -883,7 +887,7 @@ def make_datapoints(self, with_mask=True):

        yield (tensor_image, target)

-        feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
+        feature_image = make_image(size=size, color_space=features.ColorSpace.RGB)
        target = {
            "boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
            "labels": make_label(extra_dims=(num_objects,), categories=80),
@@ -927,3 +931,107 @@ def test_transform(self, t_ref, t, data_kwargs):
        expected_output = t_ref(*dp)

        assert_equal(expected_output, output)


seg_transforms = import_transforms_from_references("segmentation")


class TestSegDetTransforms:
    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        size = (256, 640)
        num_categories = 21

        conv_fns = []
        if supports_pil:
            conv_fns.append(to_image_pil)
        conv_fns.extend([torch.Tensor, lambda x: x])

        for conv_fn in conv_fns:
            feature_image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            feature_mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            dp = (conv_fn(feature_image), feature_mask)
            dp_ref = (
                to_image_pil(feature_image) if supports_pil else torch.Tensor(feature_image),
                to_image_pil(feature_mask),
            )

            yield dp, dp_ref

    def set_seed(self, seed):
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        for dp, dp_ref in self.make_datapoints(**data_kwargs or dict()):

            self.set_seed(12)
            output = t(dp)

            self.set_seed(12)
            expected_output = t_ref(*dp_ref)

            assert_equal(output, expected_output)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            # (
            #     seg_transforms.RandomCrop(size=480),
            #     prototype_transforms.RandomCrop(
            #         size=480, pad_if_needed=True, fill=defaultdict(lambda: 0, {features.Mask: 255})
            #     ),
            #     dict(),
            # ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

    def test_random_resize_train(self, mocker):
        base_size = 520
        min_size = base_size // 2
        max_size = base_size * 2

        randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
            # Forward every call to the real torch.randint, except the scalar
            # draw torch.randint(a, b, ()), which is rerouted to random.randint
            # so the prototype transform consumes the Python RNG exactly like
            # the reference implementation does.
            if kwargs or len(other_args) > 1 or other_args[0] != ():
                return randint(a, b, *other_args, **kwargs)

            return random.randint(a, b)

        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = det_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check(t, t_ref)
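
A small standalone sketch (not part of the diff, assuming the reference RandomResize draws its size via random.randint, which is what this test relies on): with both RNGs seeded identically, rerouting the scalar torch.randint draw to random.randint makes the prototype transform sample the same size as the reference. The seed and bounds below are just the values used in the test.

import random

import torch

randint = torch.randint

def patched_randint(a, b, *other_args, **kwargs):
    if kwargs or len(other_args) > 1 or other_args[0] != ():
        return randint(a, b, *other_args, **kwargs)
    return random.randint(a, b)

random.seed(12)
reference_draw = random.randint(260, 1040)  # mirrors the reference's assumed random.randint draw

random.seed(12)
prototype_draw = patched_randint(260, 1040, ())  # the rerouted scalar draw

assert reference_draw == prototype_draw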

    def test_random_resize_eval(self):
Comment (Contributor Author): Both this test and the one above currently fail. The resized mask is off by quite a bit, and so far I have no idea why, or whether that is expected. For the references, the resize is performed on a PIL image, while the prototype transforms use a features.Mask, which internally is treated like a tensor image.

Comment (Contributor Author): This happens because PIL and PyTorch don't agree on how exactly the values should be picked during nearest interpolation when the upsampling is not by an integer factor. For example:

from torchvision.prototype.transforms import functional as F
import torch

num_categories = 21
size = [520]

torch.manual_seed(123)
segmentation_mask = torch.randint(0, num_categories, (384, 640), dtype=torch.uint8)
print(segmentation_mask[:4, :4])

a = F.resize_mask(segmentation_mask, size)
print(a[:5, :5])

b = F.to_image_tensor(
    F.resize_image_pil(
        F.to_image_pil(segmentation_mask),
        size,
        interpolation=F.InterpolationMode.NEAREST,
    )
).squeeze(0)
print(b[:5, :5])

assert a.shape == b.shape

This prints:

tensor([[16,  1,  0, 12],
        [ 3,  7, 20, 15],
        [12, 19, 11, 12],
        [ 2, 20,  6,  9]], dtype=torch.uint8)
tensor([[16, 16,  1,  0,  0],
        [16, 16,  1,  0,  0],
        [ 3,  3,  7, 20, 20],
        [12, 12, 19, 11, 11],
        [12, 12, 19, 11, 11]], dtype=torch.uint8)
tensor([[16,  1,  1,  0, 12],
        [ 3,  7,  7, 20, 15],
        [ 3,  7,  7, 20, 15],
        [12, 19, 19, 11, 12],
        [ 2, 20, 20,  6,  9]], dtype=torch.uint8)

This might have an effect on the accuracy. So far I don't see a way to change this other than aligning torch.nn.functional.interpolate with what PIL does for nearest interpolation.
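
A rough sketch of where the disagreement comes from (the index formulas below are assumptions about the two samplers, not taken from either codebase, but they reproduce the row pattern in the output above): torch's nearest interpolation effectively selects src = floor(dst * scale), while PIL effectively compares pixel centers, src = int((dst + 0.5) * scale).

src_len, dst_len = 384, 520  # the mask height from the example above
scale = src_len / dst_len

# Source row picked for each of the first five destination rows.
torch_rows = [int(i * scale) for i in range(5)]
pil_rows = [min(src_len - 1, int((i + 0.5) * scale)) for i in range(5)]

print(torch_rows)  # [0, 0, 1, 2, 2] -> first column above: 16, 16, 3, 12, 12
print(pil_rows)    # [0, 1, 1, 2, 3] -> first column above: 16, 3, 3, 12, 2

For what it's worth, newer PyTorch releases also expose torch.nn.functional.interpolate(..., mode="nearest-exact"), which is documented as matching the PIL / scikit-image convention.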

        torch.manual_seed(0)
        base_size = 520

        t = prototype_transforms.Resize(size=base_size, antialias=True)

        t_ref = det_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check(t, t_ref)