Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions torchvision/prototype/transforms/functional/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ def erase(
v: torch.Tensor,
inplace: bool = False,
) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
if isinstance(inpt, torch.Tensor):
output = erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
elif isinstance(inpt, features.Image):
output = erase_image_tensor(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Image.wrap_like(inpt, output)
elif isinstance(inpt, features.Video):
output = erase_video(inpt.as_subclass(torch.Tensor), i=i, j=j, h=h, w=w, v=v, inplace=inplace)
return features.Video.wrap_like(inpt, output)
else: # isinstance(inpt, PIL.Image.Image):
return erase_image_pil(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace)
13 changes: 7 additions & 6 deletions torchvision/prototype/transforms/functional/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
def rgb_to_grayscale(
inpt: Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT], num_output_channels: int = 1
) -> Union[features.LegacyImageTypeJIT, features.LegacyVideoTypeJIT]:
old_color_space = (
features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
if isinstance(inpt, torch.Tensor)
and (torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video)))
else None
)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
inpt = inpt.as_subclass(torch.Tensor)
old_color_space = None
elif isinstance(inpt, torch.Tensor):
old_color_space = features._image._from_tensor_shape(inpt.shape) # type: ignore[arg-type]
else:
old_color_space = None

call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
Expand Down
31 changes: 20 additions & 11 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,12 +1400,16 @@ def five_crop(
inpt: ImageOrVideoTypeJIT, size: List[int]
) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
# TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
if isinstance(inpt, torch.Tensor):
output = five_crop_image_tensor(inpt, size)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
tmp = tuple(inpt.wrap_like(inpt, item) for item in output) # type: ignore[arg-type]
output = tmp # type: ignore[assignment]
return output
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
return five_crop_image_tensor(inpt, size)
elif isinstance(inpt, features.Image):
output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
return tuple(features.Image.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
elif isinstance(inpt, features.Video):
output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
return tuple(features.Video.wrap_like(inpt, item) for item in output) # type: ignore[return-value]
else: # isinstance(inpt, PIL.Image.Image):
return five_crop_image_pil(inpt, size)

Expand Down Expand Up @@ -1444,10 +1448,15 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
def ten_crop(
inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
output = [inpt.wrap_like(inpt, item) for item in output] # type: ignore[arg-type]
return output
if isinstance(inpt, torch.Tensor) and (
torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
):
return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
elif isinstance(inpt, features.Image):
output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
return [features.Image.wrap_like(inpt, item) for item in output]
elif isinstance(inpt, features.Video):
output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
return [features.Video.wrap_like(inpt, item) for item in output]
else: # isinstance(inpt, PIL.Image.Image):
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)