Skip to content

Commit 598e3f0

Browse files
datumbox authored and facebook-github-bot committed
[fbsync] unwrap features in dispatchers (#6831)
Summary: * unwrap features in dispatchers * cleanup * align erase / five_crop / ten_crop with other dispatchers Reviewed By: YosuaMichael Differential Revision: D40722913 fbshipit-source-id: eab1c060d1a9de2a587535b62f223be49ce52792
1 parent d05f513 commit 598e3f0

File tree

3 files changed

+37
-22
lines changed

3 files changed

+37
-22
lines changed

torchvision/prototype/transforms/functional/_augment.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,15 @@ def erase(
3434
v: torch.Tensor,
3535
inplace: bool = False,
3636
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
37-
if isinstance(inpt, torch.Tensor):
38-
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
39-
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
40-
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
41-
return output
37+
if isinstance(inpt, torch.Tensor) and (
38+
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
39+
):
40+
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
41+
elif isinstance(inpt, features.Image):
42+
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
43+
return features.Image.wrap_like(inpt, output)
44+
elif isinstance(inpt, features.Video):
45+
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
46+
return features.Video.wrap_like(inpt, output)
4247
else: # isinstance(inpt, PIL.Image.Image):
4348
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)

torchvision/prototype/transforms/functional/_deprecated.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
2525
def rgb_to_grayscale(
2626
inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
2727
) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
28-
old_color_space = (
29-
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
30-
if isinstance(inpt, torch.Tensor)
31-
and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
32-
else None
33-
)
28+
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
29+
inpt = inpt.as_subclass(torch.Tensor)
30+
old_color_space = None
31+
elif isinstance(inpt, torch.Tensor):
32+
old_color_space = features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
33+
else:
34+
old_color_space = None
3435

3536
call = ", num_output_channels=3" if num_output_channels == 3 else ""
3637
replacement = (

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,12 +1400,16 @@ def five_crop(
14001400
inpt: ImageOrVideoTypeJIT, size: List[int]
14011401
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
14021402
# TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
1403-
if isinstance(inpt, torch.Tensor):
1404-
output = five_crop_image_tensor(inpt, size)
1405-
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
1406-
tmp = tuple(inpt.wrap_like(inpt, item) for item in output) # type: ignore[arg-type]
1407-
output = tmp # type: ignore[assignment]
1408-
return output
1403+
if isinstance(inpt, torch.Tensor) and (
1404+
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
1405+
):
1406+
return five_crop_image_tensor(inpt, size)
1407+
elif isinstance(inpt, features.Image):
1408+
output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
1409+
return tuple(features.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
1410+
elif isinstance(inpt, features.Video):
1411+
output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
1412+
return tuple(features.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
14091413
else: # isinstance(inpt, PIL.Image.Image):
14101414
return five_crop_image_pil(inpt, size)
14111415

@@ -1444,10 +1448,15 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
14441448
def ten_crop(
14451449
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
14461450
) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
1447-
if isinstance(inpt, torch.Tensor):
1448-
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
1449-
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
1450-
output = [inpt.wrap_like(inpt, item) for item in output] # type: ignore[arg-type]
1451-
return output
1451+
if isinstance(inpt, torch.Tensor) and (
1452+
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
1453+
):
1454+
return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
1455+
elif isinstance(inpt, features.Image):
1456+
output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
1457+
return [features.Image.wrap_like(inpt, item) for item in output]
1458+
elif isinstance(inpt, features.Video):
1459+
output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
1460+
return [features.Video.wrap_like(inpt, item) for item in output]
14521461
else: # isinstance(inpt, PIL.Image.Image):
14531462
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)

0 commit comments

Comments (0)