
Commit b179c66

Added GaussianBlur transform and tests
1 parent 615b175 commit b179c66

8 files changed, +142 -4 lines changed

test/test_prototype_transforms.py

Lines changed: 56 additions & 0 deletions
@@ -644,3 +644,59 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
         else:
             # vfdev-5: I do not know how to mock and test this case
             pass
+
+
+class TestGaussianBlur:
+    def test_assertions(self):
+        with pytest.raises(ValueError, match="Kernel size should be a tuple/list of two integers"):
+            transforms.GaussianBlur([10, 12, 14])
+
+        with pytest.raises(ValueError, match="Kernel size value should be an odd and positive number"):
+            transforms.GaussianBlur(4)
+
+        with pytest.raises(TypeError, match="sigma should be a single float or a list/tuple with length 2"):
+            transforms.GaussianBlur(3, sigma=[1, 2, 3])
+
+        with pytest.raises(ValueError, match="If sigma is a single number, it must be positive"):
+            transforms.GaussianBlur(3, sigma=-1.0)
+
+        with pytest.raises(ValueError, match="sigma values should be positive and of the form"):
+            transforms.GaussianBlur(3, sigma=[2.0, 1.0])
+
+    @pytest.mark.parametrize("sigma", [10.0, [10.0, 12.0]])
+    def test__get_params(self, sigma):
+        transform = transforms.GaussianBlur(3, sigma=sigma)
+        params = transform._get_params(None)
+
+        if isinstance(sigma, float):
+            assert params["sigma"][0] == params["sigma"][1] == 10
+        else:
+            assert sigma[0] <= params["sigma"][0] <= sigma[1]
+            assert sigma[0] <= params["sigma"][1] <= sigma[1]
+
+    @pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)])
+    @pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]])
+    def test__transform(self, kernel_size, sigma, mocker):
+        transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
+
+        if isinstance(kernel_size, (tuple, list)):
+            assert transform.kernel_size == kernel_size
+        else:
+            assert transform.kernel_size == (kernel_size, kernel_size)
+
+        if isinstance(sigma, (tuple, list)):
+            assert transform.sigma == sigma
+        else:
+            assert transform.sigma == (sigma, sigma)
+
+
+        fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
+        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        # vfdev-5, Feature Request: let's store params as Transform attribute
+        # This could be also helpful for users
+        torch.manual_seed(12)
+        _ = transform(inpt)
+        torch.manual_seed(12)
+        params = transform._get_params(inpt)
+
+        fn.assert_called_once_with(inpt, **params)

test/test_prototype_transforms_functional.py

Lines changed: 11 additions & 0 deletions
@@ -495,6 +495,7 @@ def center_crop_bounding_box():
         )
 
 
+@register_kernel_info_from_sample_inputs_fn
 def center_crop_segmentation_mask():
     for mask, output_size in itertools.product(
         make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
@@ -503,6 +504,16 @@ def center_crop_segmentation_mask():
         yield SampleInput(mask, output_size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def gaussian_blur_image_tensor():
+    for image, kernel_size, sigma in itertools.product(
+        make_images(extra_dims=((4,),)),
+        [[3, 3], ],
+        [None, [3.0, 3.0]],
+    ):
+        yield SampleInput(image, kernel_size=kernel_size, sigma=sigma)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [

torchvision/prototype/features/_feature.py

Lines changed: 3 additions & 0 deletions
@@ -189,3 +189,6 @@ def equalize(self) -> Any:
 
     def invert(self) -> Any:
         return self
+
+    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Any:
+        return self

torchvision/prototype/features/_image.py

Lines changed: 6 additions & 0 deletions
@@ -309,3 +309,9 @@ def invert(self) -> Image:
 
         output = _F.invert_image_tensor(self)
         return Image.new_like(self, output)
+
+    def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Image:
+        from torchvision.prototype.transforms import functional as _F
+
+        output = _F.gaussian_blur_image_tensor(self, kernel_size=kernel_size, sigma=sigma)
+        return Image.new_like(self, output)

torchvision/prototype/transforms/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -22,7 +22,10 @@
     RandomAffine,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
-from ._misc import Identity, Normalize, ToDtype, Lambda
+from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda
 from ._type_conversion import DecodeImage, LabelToOneHot
 
 from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor  # usort: skip
+
+# TODO: add RandomPerspective, RandomInvert, RandomPosterize, RandomSolarize,
+# RandomAdjustSharpness, RandomAutocontrast, ElasticTransform

torchvision/prototype/transforms/_misc.py

Lines changed: 32 additions & 1 deletion
@@ -1,7 +1,8 @@
 import functools
-from typing import Any, List, Type, Callable, Dict
+from typing import Any, List, Type, Callable, Dict, Sequence, Union
 
 import torch
+from torchvision.transforms.transforms import _setup_size
 from torchvision.prototype.transforms import Transform, functional as F
 
 
@@ -46,6 +47,36 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
         return input
 
 
+class GaussianBlur(Transform):
+    def __init__(
+        self, kernel_size: Union[int, Sequence[int]], sigma: Union[float, Sequence[float]] = (0.1, 2.0)
+    ) -> None:
+        super().__init__()
+        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
+        for ks in self.kernel_size:
+            if ks <= 0 or ks % 2 == 0:
+                raise ValueError("Kernel size value should be an odd and positive number.")
+
+        if isinstance(sigma, float):
+            if sigma <= 0:
+                raise ValueError("If sigma is a single number, it must be positive.")
+            sigma = (sigma, sigma)
+        elif isinstance(sigma, Sequence) and len(sigma) == 2:
+            if not 0.0 < sigma[0] <= sigma[1]:
+                raise ValueError("sigma values should be positive and of the form (min, max).")
+        else:
+            raise TypeError("sigma should be a single float or a list/tuple with length 2 floats.")
+
+        self.sigma = sigma
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        sigma = torch.empty(1).uniform_(self.sigma[0], self.sigma[1]).item()
+        return dict(sigma=[sigma, sigma])
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.gaussian_blur(inpt, **params)
+
+
 class ToDtype(Lambda):
     def __init__(self, dtype: torch.dtype, *types: Type) -> None:
         self.dtype = dtype
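
A rough usage sketch of the transform added above, assuming the prototype namespace at this commit (torchvision.prototype.transforms); the kernel size and sigma values are only illustrative:

import torch
from torchvision.prototype import transforms

# The constructor normalizes scalar arguments, mirroring the checks above: an int
# kernel size becomes a (k, k) pair via _setup_size, and sigma is kept as a
# (min, max) range.
blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.5, 2.0))
print(blur.kernel_size)  # (3, 3)
print(blur.sigma)        # (0.5, 2.0)

# _get_params() draws one sigma uniformly from the (min, max) range and duplicates
# it for both spatial dimensions before _transform() forwards it to F.gaussian_blur.
params = blur._get_params(None)
print(params)  # e.g. {'sigma': [1.37, 1.37]}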

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -99,7 +99,12 @@
     ten_crop_image_tensor,
     ten_crop_image_pil,
 )
-from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
+from ._misc import (
+    normalize_image_tensor,
+    gaussian_blur,
+    gaussian_blur_image_tensor,
+    gaussian_blur_image_pil,
+)
 from ._type_conversion import (
     decode_image_with_pil,
     decode_video_with_av,

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 24 additions & 1 deletion
@@ -1,14 +1,28 @@
-from typing import Optional, List
+from typing import Optional, List, Union
 
 import PIL.Image
 import torch
+from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
 
+# shortcut type
+DType = Union[torch.Tensor, PIL.Image.Image, features._Feature]
+
+
 normalize_image_tensor = _FT.normalize
 
 
+def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
+    if isinstance(inpt, features.Image):
+        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+    elif type(inpt) == torch.Tensor:
+        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+    else:
+        raise TypeError("Unsupported input type")
+
+
 def gaussian_blur_image_tensor(
     img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
@@ -42,3 +56,12 @@ def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optio
     t_img = pil_to_tensor(img)
     output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
     return to_pil_image(output, mode=img.mode)
+
+
+def gaussian_blur(inpt: DType, kernel_size: List[int], sigma: Optional[List[float]] = None) -> DType:
+    if isinstance(inpt, features._Feature):
+        return inpt.gaussian_blur(kernel_size=kernel_size, sigma=sigma)
+    elif isinstance(inpt, PIL.Image.Image):
+        return gaussian_blur_image_pil(inpt, kernel_size=kernel_size, sigma=sigma)
+    else:
+        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
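
A rough sketch of how the new gaussian_blur dispatcher routes on input type, assuming the prototype namespaces at this commit (torchvision.prototype.features and torchvision.prototype.transforms.functional); shapes and parameter values are only illustrative:

import torch
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

kernel_size = [3, 3]
sigma = [1.5, 1.5]

# Plain tensors fall through to the gaussian_blur_image_tensor kernel.
out_tensor = F.gaussian_blur(torch.rand(1, 3, 32, 32), kernel_size=kernel_size, sigma=sigma)

# PIL images are routed to gaussian_blur_image_pil, which round-trips through a tensor.
out_pil = F.gaussian_blur(PIL.Image.new("RGB", (32, 32)), kernel_size=kernel_size, sigma=sigma)

# Prototype features delegate to their own .gaussian_blur() method; Image overrides it
# to call gaussian_blur_image_tensor and re-wrap the result with Image.new_like.
image = features.Image(torch.rand(1, 3, 32, 32))
out_image = F.gaussian_blur(image, kernel_size=kernel_size, sigma=sigma)
print(type(out_image))  # features.Image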
