Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,21 +199,31 @@ def resize_bounding_box():
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)


class TestKernelsCommon:
@pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name)
def test_scriptable(self, functional_info):
jit.script(functional_info.functional)

@pytest.mark.parametrize(
("functional_info", "sample_input"),
[
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
for functional_info in FUNCTIONAL_INFOS
for idx, sample_input in enumerate(functional_info.sample_inputs())
],
)
def test_eager_vs_scripted(self, functional_info, sample_input):
eager = functional_info(sample_input)
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)

torch.testing.assert_close(eager, scripted)
@pytest.mark.parametrize(
"kernel",
[
pytest.param(kernel, id=name)
for name, kernel in F.__dict__.items()
if not name.startswith("_")
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
and "pil" not in name
],
)
def test_scriptable(kernel):
jit.script(kernel)


@pytest.mark.parametrize(
("functional_info", "sample_input"),
[
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}")
for functional_info in FUNCTIONAL_INFOS
for idx, sample_input in enumerate(functional_info.sample_inputs())
],
)
def test_eager_vs_scripted(functional_info, sample_input):
eager = functional_info(sample_input)
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)

torch.testing.assert_close(eager, scripted)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest.mock
from typing import Dict, Any, Tuple, cast
from typing import Dict, Any, Tuple

import numpy as np
import PIL.Image
Expand All @@ -22,4 +22,4 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor


def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return cast(torch.Tensor, one_hot(label, num_classes=num_categories))
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]