Skip to content

Commit 3e97c1c

Browse files
committed
Added non-scalar fill support workaround for pad
1 parent 6a94f17 commit 3e97c1c

File tree

4 files changed

+65
-6
lines changed

4 files changed

+65
-6
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,17 @@ def resized_crop_segmentation_mask():
426426
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
427427

428428

429+
@register_kernel_info_from_sample_inputs_fn
430+
def pad_image_tensor():
431+
for image, padding, fill, padding_mode in itertools.product(
432+
make_images(),
433+
[[1], [1, 1], [1, 1, 2, 2]], # padding
434+
[12], # fill
435+
["constant", "symmetric", "edge", "reflect"], # padding mode,
436+
):
437+
yield SampleInput(image, padding=padding, fill=fill, padding_mode=padding_mode)
438+
439+
429440
@register_kernel_info_from_sample_inputs_fn
430441
def pad_segmentation_mask():
431442
for mask, padding, padding_mode in itertools.product(

torchvision/prototype/features/_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def resized_crop(
136136
# How dangerous to do this instead of raising an error ?
137137
return self
138138

139-
def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any:
139+
def pad(self, padding: List[int], fill: Union[float, Sequence[float]] = 0, padding_mode: str = "constant") -> Any:
140140
# Just output itself
141141
# How dangerous to do this instead of raising an error ?
142142
return self

torchvision/prototype/features/_image.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,17 @@ def resized_crop(
163163
)
164164
return Image.new_like(self, output)
165165

166-
def pad(self, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Image:
166+
def pad(self, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant") -> Image:
167167
from torchvision.prototype.transforms import functional as _F
168168

169-
output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
169+
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
170+
if isinstance(fill, (int, float)):
171+
output = _F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
172+
else:
173+
from torchvision.prototype.transforms.functional._geometry import _pad_with_vector_fill
174+
175+
output = _pad_with_vector_fill(self, padding, fill=fill, padding_mode=padding_mode)
176+
170177
return Image.new_like(self, output)
171178

172179
def rotate(

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,10 +503,45 @@ def rotate(
503503
return inpt
504504

505505

506-
pad_image_tensor = _FT.pad
507506
pad_image_pil = _FP.pad
508507

509508

509+
def pad_image_tensor(
510+
img: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant"
511+
) -> torch.Tensor:
512+
num_masks, height, width = img.shape[-3:]
513+
extra_dims = img.shape[:-3]
514+
515+
padded_image = _FT.pad(
516+
img=img.view(-1, num_masks, height, width), padding=padding, fill=fill, padding_mode=padding_mode
517+
)
518+
519+
new_height, new_width = padded_image.shape[-2:]
520+
return padded_image.view(extra_dims + (num_masks, new_height, new_width))
521+
522+
523+
# TODO: This should be removed once pytorch pad supports non-scalar padding values
524+
def _pad_with_vector_fill(
525+
img: torch.Tensor, padding: List[int], fill: Union[float, List[float]] = 0.0, padding_mode: str = "constant"
526+
):
527+
if padding_mode != "constant":
528+
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
529+
530+
output = pad_image_tensor(img, padding, fill=0, padding_mode="constant")
531+
left, top, right, bottom = padding
532+
fill = torch.tensor(fill, dtype=img.dtype, device=img.device).view(-1, 1, 1)
533+
534+
if top > 0:
535+
output[..., :top, :] = fill
536+
if left > 0:
537+
output[..., :, :left] = fill
538+
if bottom > 0:
539+
output[..., -bottom:, :] = fill
540+
if right > 0:
541+
output[..., :, -right:] = fill
542+
return output
543+
544+
510545
def pad_segmentation_mask(
511546
segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
512547
) -> torch.Tensor:
@@ -537,13 +572,19 @@ def pad_bounding_box(
537572
return bounding_box
538573

539574

540-
def pad(inpt: Any, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any:
575+
def pad(
576+
inpt: Any, padding: List[int], fill: Union[float, Sequence[float]] = 0.0, padding_mode: str = "constant"
577+
) -> Any:
541578
if isinstance(inpt, features._Feature):
542579
return inpt.pad(padding, fill=fill, padding_mode=padding_mode)
543580
elif isinstance(inpt, PIL.Image.Image):
544581
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
545582
elif isinstance(inpt, torch.Tensor):
546-
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
583+
# PyTorch's pad supports only scalars on fill. So we need to overwrite the colour
584+
if isinstance(fill, (int, float)):
585+
return pad_image_tensor(inpt, padding, fill=fill, padding_mode=padding_mode)
586+
else:
587+
return _pad_with_vector_fill(inpt, padding, fill=fill, padding_mode=padding_mode)
547588
else:
548589
return inpt
549590

0 commit comments

Comments
 (0)