@@ -1400,12 +1400,16 @@ def five_crop(
     inpt: ImageOrVideoTypeJIT, size: List[int]
 ) -> Tuple[ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT, ImageOrVideoTypeJIT]:
     # TODO: consider breaking BC here to return List[features.ImageTypeJIT/VideoTypeJIT] to align this op with `ten_crop`
-    if isinstance(inpt, torch.Tensor):
-        output = five_crop_image_tensor(inpt, size)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            tmp = tuple(inpt.wrap_like(inpt, item) for item in output)  # type: ignore[arg-type]
-            output = tmp  # type: ignore[assignment]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return five_crop_image_tensor(inpt, size)
+    elif isinstance(inpt, features.Image):
+        output = five_crop_image_tensor(inpt.as_subclass(torch.Tensor), size)
+        return tuple(features.Image.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
+    elif isinstance(inpt, features.Video):
+        output = five_crop_video(inpt.as_subclass(torch.Tensor), size)
+        return tuple(features.Video.wrap_like(inpt, item) for item in output)  # type: ignore[return-value]
     else:  # isinstance(inpt, PIL.Image.Image):
         return five_crop_image_pil(inpt, size)
 
@@ -1444,10 +1448,15 @@ def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = F
 def ten_crop(
     inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], size: List[int], vertical_flip: bool = False
 ) -> Union[List[features.ImageTypeJIT], List[features.VideoTypeJIT]]:
-    if isinstance(inpt, torch.Tensor):
-        output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
-        if not torch.jit.is_scripting() and isinstance(inpt, (features.Image, features.Video)):
-            output = [inpt.wrap_like(inpt, item) for item in output]  # type: ignore[arg-type]
-        return output
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
+    elif isinstance(inpt, features.Image):
+        output = ten_crop_image_tensor(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
+        return [features.Image.wrap_like(inpt, item) for item in output]
+    elif isinstance(inpt, features.Video):
+        output = ten_crop_video(inpt.as_subclass(torch.Tensor), size, vertical_flip=vertical_flip)
+        return [features.Video.wrap_like(inpt, item) for item in output]
     else:  # isinstance(inpt, PIL.Image.Image):
         return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
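
Note: the refactored dispatchers send plain tensors straight to the tensor kernel, while `features.Image` and `features.Video` inputs are unwrapped via `as_subclass(torch.Tensor)`, run through the matching kernel, and re-wrapped per crop with `wrap_like`. The following is a minimal usage sketch of that behaviour, assuming the prototype import paths `torchvision.prototype.features` and `torchvision.prototype.transforms.functional` (these prototype namespaces may differ or have moved in other torchvision versions):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

plain = torch.rand(3, 64, 64)
image = features.Image(torch.rand(3, 64, 64))

# Plain tensors take the fast path and come back as plain tensors.
crops = F.five_crop(plain, [32, 32])
assert all(type(c) is torch.Tensor for c in crops)

# Feature inputs are unwrapped before the kernel runs and re-wrapped afterwards,
# so each returned crop keeps its feature type.
crops = F.five_crop(image, [32, 32])
assert all(isinstance(c, features.Image) for c in crops)

# ten_crop returns a list of ten crops: the five crops plus five crops of the
# flipped input (vertical flip when vertical_flip=True, horizontal otherwise).
flipped = F.ten_crop(image, [32, 32], vertical_flip=True)
assert len(flipped) == 10 and all(isinstance(c, features.Image) for c in flipped)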