@@ -627,66 +627,77 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)


-class RandomPerspective(object):
-    """Performs Perspective transformation of the given PIL Image randomly with a given probability.
+class RandomPerspective(torch.nn.Module):
+    """Performs a random perspective transformation of the given image with a given probability.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

     Args:
-        interpolation : Default- Image.BICUBIC
-
-        p (float): probability of the image being perspectively transformed. Default value is 0.5
-
-        distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.
+        distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
+            Default is 0.5.
+        p (float): probability of the image being transformed. Default is 0.5.
+        interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and
+            ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors.
+        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
+            image. If int or float, the value is used for all bands respectively. Default is 0.
+            This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor
+            input. Fill value for the area outside the transform in the output image is always 0.

-        fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
-            If int, it is used for all channels respectively. Default value is 0.
     """

-    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
+    def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0):
+        super().__init__()
         self.p = p
         self.interpolation = interpolation
         self.distortion_scale = distortion_scale
         self.fill = fill

-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be Perspectively transformed.
+            img (PIL Image or Tensor): Image to be Perspectively transformed.

         Returns:
-            PIL Image: Random perspectivley transformed image.
+            PIL Image or Tensor: Randomly transformed image.
         """
-        if not F._is_pil_image(img):
-            raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-        if random.random() < self.p:
-            width, height = img.size
+        if torch.rand(1) < self.p:
+            width, height = F._get_image_size(img)
             startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
             return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
         return img

     @staticmethod
-    def get_params(width, height, distortion_scale):
+    def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[List[int]], List[List[int]]]:
         """Get parameters for ``perspective`` for a random perspective transform.

         Args:
-            width : width of the image.
-            height : height of the image.
+            width (int): width of the image.
+            height (int): height of the image.
+            distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.

         Returns:
             List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
             List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
         """
-        half_height = int(height / 2)
-        half_width = int(width / 2)
-        topleft = (random.randint(0, int(distortion_scale * half_width)),
-                   random.randint(0, int(distortion_scale * half_height)))
-        topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
-                    random.randint(0, int(distortion_scale * half_height)))
-        botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1),
-                    random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
-        botleft = (random.randint(0, int(distortion_scale * half_width)),
-                   random.randint(height - int(distortion_scale * half_height) - 1, height - 1))
-        startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)]
+        half_height = height // 2
+        half_width = width // 2
+        topleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
+        ]
+        topright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
+        ]
+        botright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
+        ]
+        botleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
+        ]
+        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
         endpoints = [topleft, topright, botright, botleft]
         return startpoints, endpoints

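For anyone trying the change locally, here is a minimal usage sketch. It assumes a torchvision build that contains this patch (where `RandomPerspective` is an `nn.Module` accepting both PIL Images and tensors) plus a Pillow install; the synthetic image and all parameter values below are placeholders for illustration, not part of the diff.

```python
import torch
from PIL import Image
from torchvision import transforms

# The transform is now a torch.nn.Module, so it composes like any other module.
transform = transforms.RandomPerspective(distortion_scale=0.5, p=1.0)

# PIL input: a warped PIL Image comes back.
pil_img = Image.new("RGB", (128, 96), color=(128, 64, 32))  # placeholder image
out_pil = transform(pil_img)

# Tensor input: [..., H, W] float tensors are accepted as well.
tensor_img = torch.rand(3, 96, 128)
out_tensor = transform(tensor_img)

# get_params can be called directly to sample one warp and reuse it,
# e.g. to apply the identical perspective to an image and its mask.
startpoints, endpoints = transforms.RandomPerspective.get_params(128, 96, 0.5)
print(startpoints, endpoints)
```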