Commit 112accf

Cleanup prototype kernels for degenerate inputs (#6544)

* avoid double padding parsing
* remove cloning in degenerate case
* fix affine and rotate for degenerate inputs
* fix rotate for degenerate inputs if expand=True

1 parent 84dcf69 commit 112accf
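
A note on terminology: a "degenerate" input here is an image or mask tensor with a zero-sized batch/extra dimension, i.e. a tensor with no elements. The kernels cannot run the actual resampling ops on such inputs, but they still have to return a result of the correct output shape. A minimal illustration in plain PyTorch (not the kernel code itself; the sizes are arbitrary):

import torch

# A degenerate image batch: zero images of size 3x32x32, so numel() == 0.
image = torch.empty(0, 3, 32, 32)
assert image.numel() == 0

# A resize of this batch to 16x16 must still produce shape (0, 3, 16, 16),
# even though there is no pixel data to resample. Producing an empty tensor
# of the target shape is all the work there is to do.
resized = image.new_empty((0, 3, 16, 16))
assert resized.shape == (0, 3, 16, 16)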

File tree

3 files changed: +31 -24 lines changed

test/test_prototype_transforms.py

Lines changed: 2 additions & 2 deletions
@@ -86,15 +86,15 @@ class TestSmoke:
         transforms.RandomHorizontalFlip(),
         transforms.Pad(5),
         transforms.RandomZoomOut(),
+        transforms.RandomRotation(degrees=(-45, 45)),
+        transforms.RandomAffine(degrees=(-45, 45)),
         transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
         # TODO: Something wrong with input data setup. Let's fix that
         # transforms.RandomEqualize(),
         # transforms.RandomInvert(),
         # transforms.RandomPosterize(bits=4),
         # transforms.RandomSolarize(threshold=0.5),
         # transforms.RandomAdjustSharpness(sharpness_factor=0.5),
-        # transforms.RandomRotation(degrees=(-45, 45)),
-        # transforms.RandomAffine(degrees=(-45, 45)),
     )
     def test_common(self, transform, input):
         transform(input)

test/test_prototype_transforms_functional.py

Lines changed: 4 additions & 4 deletions
@@ -290,7 +290,7 @@ def resize_segmentation_mask():
 @register_kernel_info_from_sample_inputs_fn
 def affine_image_tensor():
     for image, angle, translate, scale, shear in itertools.product(
-        make_images(extra_dims=((), (4,))),
+        make_images(),
         [-87, 15, 90],  # angle
         [5, -5],  # translate
         [0.77, 1.27],  # scale
@@ -329,7 +329,7 @@ def affine_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def affine_segmentation_mask():
     for mask, angle, translate, scale, shear in itertools.product(
-        make_segmentation_masks(extra_dims=((), (4,)), num_objects=[10]),
+        make_segmentation_masks(),
         [-87, 15, 90],  # angle
         [5, -5],  # translate
         [0.77, 1.27],  # scale
@@ -347,7 +347,7 @@ def affine_segmentation_mask():
 @register_kernel_info_from_sample_inputs_fn
 def rotate_image_tensor():
     for image, angle, expand, center, fill in itertools.product(
-        make_images(extra_dims=((), (4,))),
+        make_images(),
         [-87, 15, 90],  # angle
         [True, False],  # expand
         [None, [12, 23]],  # center
@@ -382,7 +382,7 @@ def rotate_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def rotate_segmentation_mask():
     for mask, angle, expand, center in itertools.product(
-        make_segmentation_masks(extra_dims=((), (4,)), num_objects=[10]),
+        make_segmentation_masks(),
         [-87, 15, 90],  # angle
         [True, False],  # expand
         [None, [12, 23]],  # center

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 25 additions & 18 deletions
@@ -108,17 +108,14 @@ def resize_image_tensor(
     extra_dims = image.shape[:-3]
 
     if image.numel() > 0:
-        resized_image = _FT.resize(
+        image = _FT.resize(
             image.view(-1, num_channels, old_height, old_width),
             size=[new_height, new_width],
             interpolation=interpolation.value,
             antialias=antialias,
         )
-    else:
-        # TODO: the cloning is probably unnecessary. Review this together with the other perf candidates
-        resized_image = image.clone()
 
-    return resized_image.view(extra_dims + (num_channels, new_height, new_width))
+    return image.view(extra_dims + (num_channels, new_height, new_width))
 
 
 def resize_image_pil(
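
The degenerate branch of resize_image_tensor previously cloned the input before reshaping it; the clone is gone because the trailing view already returns the empty input reshaped to the output size, and reshaping a tensor with zero elements only rewrites shape metadata. A tiny sketch of why that view is legal, in plain PyTorch with arbitrary sizes:

import torch

image = torch.empty(2, 0, 3, 32, 32)  # extra_dims == (2, 0): no actual images
extra_dims = image.shape[:-3]

# Both the input and the target shape hold zero elements, so view() succeeds
# even though 32x32 != 16x16, and no data is copied (unlike the old clone()).
out = image.view(extra_dims + (3, 16, 16))
assert out.shape == (2, 0, 3, 16, 16)
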
@@ -229,6 +226,9 @@ def affine_image_tensor(
     fill: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
+    if img.numel() == 0:
+        return img
+
     num_channels, height, width = img.shape[-3:]
     extra_dims = img.shape[:-3]
     img = img.view(-1, num_channels, height, width)
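
The bare early return is enough for affine_image_tensor because, unlike rotate with expand=True, the affine image kernel preserves the spatial size of its input, so an empty input already has exactly the shape the output would have; there is no size bookkeeping left to do. Illustrative only:

import torch

img = torch.empty(4, 0, 3, 24, 32)  # empty batch with extra dims
out = img                           # sketch of the degenerate early return
assert out.shape == img.shape       # affine would not change H or W anyway
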
@@ -452,23 +452,32 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     num_channels, height, width = img.shape[-3:]
     extra_dims = img.shape[:-3]
-    img = img.view(-1, num_channels, height, width)
 
     center_f = [0.0, 0.0]
     if center is not None:
         if expand:
             warnings.warn("The provided center argument has no effect on the result if expand is True")
         else:
-            _, height, width = get_dimensions_image_tensor(img)
             # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
             center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
 
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
-    output = _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
-    new_height, new_width = output.shape[-2:]
-    return output.view(extra_dims + (num_channels, new_height, new_width))
+
+    if img.numel() > 0:
+        img = _FT.rotate(
+            img.view(-1, num_channels, height, width),
+            matrix,
+            interpolation=interpolation.value,
+            expand=expand,
+            fill=fill,
+        )
+        new_height, new_width = img.shape[-2:]
+    else:
+        new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
+
+    return img.view(extra_dims + (num_channels, new_height, new_width))
 
 
 def rotate_image_pil(
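
For an empty input with expand=True, rotate_image_tensor can no longer read the output size off the rotated tensor, so it asks _FT._compute_affine_output_size what size the expanded canvas would have had. Conceptually that size is the axis-aligned bounding box of the rotated image rectangle; a rough standalone sketch of the same idea (the real helper works from the affine matrix and its rounding may differ by a pixel):

import math

def expanded_rotation_size(width, height, angle_degrees):
    # Bounding box of a width x height rectangle rotated about its center.
    theta = math.radians(angle_degrees)
    new_width = abs(width * math.cos(theta)) + abs(height * math.sin(theta))
    new_height = abs(width * math.sin(theta)) + abs(height * math.cos(theta))
    return int(round(new_width)), int(round(new_height))

print(expanded_rotation_size(32, 24, 45))  # roughly (40, 40)
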
@@ -557,19 +566,17 @@ def pad_image_tensor(
     num_channels, height, width = img.shape[-3:]
     extra_dims = img.shape[:-3]
 
-    left, right, top, bottom = _FT._parse_pad_padding(padding)
-    new_height = height + top + bottom
-    new_width = width + left + right
-
     if img.numel() > 0:
-        padded_image = _FT.pad(
+        img = _FT.pad(
             img=img.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
         )
+        new_height, new_width = img.shape[-2:]
     else:
-        # TODO: the cloning is probably unnecessary. Review this together with the other perf candidates
-        padded_image = img.clone()
+        left, right, top, bottom = _FT._parse_pad_padding(padding)
+        new_height = height + top + bottom
+        new_width = width + left + right
 
-    return padded_image.view(extra_dims + (num_channels, new_height, new_width))
+    return img.view(extra_dims + (num_channels, new_height, new_width))
 
 
 # TODO: This should be removed once pytorch pad supports non-scalar padding values
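
pad_image_tensor now parses the padding only on the degenerate path: for non-empty inputs the output size is read straight from the padded tensor, which is the "avoid double padding parsing" item from the commit message, and for empty inputs the size is reconstructed arithmetically from the per-side amounts. A minimal sketch of that arithmetic, assuming a uniform padding of 2 (which is what the per-side parse yields for padding=2):

import torch

img = torch.empty(0, 3, 24, 32)    # zero images of size 3x24x32
left = right = top = bottom = 2    # per-side amounts for padding=2
new_height = 24 + top + bottom     # 28
new_width = 32 + left + right      # 36

# Nothing to pad, so just hand back an empty tensor of the padded shape; the
# view is legal because both shapes contain zero elements.
out = img.view((0, 3, new_height, new_width))
assert out.shape == (0, 3, 28, 36)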
