1111from torchvision .transforms .functional import (
1212 _compute_resized_output_size as __compute_resized_output_size ,
1313 _get_inverse_affine_matrix ,
14+ _get_perspective_coeffs ,
1415 InterpolationMode ,
1516 pil_modes_mapping ,
1617 pil_to_tensor ,
@@ -906,12 +907,32 @@ def crop(inpt: features.InputTypeJIT, top: int, left: int, height: int, width: i
906907 return crop_image_pil (inpt , top , left , height , width )
907908
908909
910+ def _perspective_coefficients (
911+ startpoints : Optional [List [List [int ]]],
912+ endpoints : Optional [List [List [int ]]],
913+ coefficients : Optional [List [float ]],
914+ ) -> List [float ]:
915+ if coefficients is not None :
916+ if startpoints is not None and endpoints is not None :
917+ raise ValueError ("The startpoints/endpoints and the coefficients shouldn't be defined concurrently." )
918+ elif len (coefficients ) != 8 :
919+ raise ValueError ("Argument coefficients should have 8 float values" )
920+ return coefficients
921+ elif startpoints is not None and endpoints is not None :
922+ return _get_perspective_coeffs (startpoints , endpoints )
923+ else :
924+ raise ValueError ("Either the startpoints/endpoints or the coefficients must have non `None` values." )
925+
926+
909927def perspective_image_tensor (
910928 image : torch .Tensor ,
911- perspective_coeffs : List [float ],
929+ startpoints : Optional [List [List [int ]]],
930+ endpoints : Optional [List [List [int ]]],
912931 interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
913932 fill : features .FillTypeJIT = None ,
933+ coefficients : Optional [List [float ]] = None ,
914934) -> torch .Tensor :
935+ perspective_coeffs = _perspective_coefficients (startpoints , endpoints , coefficients )
915936 if image .numel () == 0 :
916937 return image
917938
@@ -934,21 +955,24 @@ def perspective_image_tensor(
934955@torch .jit .unused
935956def perspective_image_pil (
936957 image : PIL .Image .Image ,
937- perspective_coeffs : List [float ],
958+ startpoints : Optional [List [List [int ]]],
959+ endpoints : Optional [List [List [int ]]],
938960 interpolation : InterpolationMode = InterpolationMode .BICUBIC ,
939961 fill : features .FillTypeJIT = None ,
962+ coefficients : Optional [List [float ]] = None ,
940963) -> PIL .Image .Image :
964+ perspective_coeffs = _perspective_coefficients (startpoints , endpoints , coefficients )
941965 return _FP .perspective (image , perspective_coeffs , interpolation = pil_modes_mapping [interpolation ], fill = fill )
942966
943967
944968def perspective_bounding_box (
945969 bounding_box : torch .Tensor ,
946970 format : features .BoundingBoxFormat ,
947- perspective_coeffs : List [float ],
971+ startpoints : Optional [List [List [int ]]],
972+ endpoints : Optional [List [List [int ]]],
973+ coefficients : Optional [List [float ]] = None ,
948974) -> torch .Tensor :
949-
950- if len (perspective_coeffs ) != 8 :
951- raise ValueError ("Argument perspective_coeffs should have 8 float values" )
975+ perspective_coeffs = _perspective_coefficients (startpoints , endpoints , coefficients )
952976
953977 original_shape = bounding_box .shape
954978 bounding_box = (
@@ -1029,8 +1053,10 @@ def perspective_bounding_box(
10291053
10301054def perspective_mask (
10311055 mask : torch .Tensor ,
1032- perspective_coeffs : List [float ],
1056+ startpoints : Optional [List [List [int ]]],
1057+ endpoints : Optional [List [List [int ]]],
10331058 fill : features .FillTypeJIT = None ,
1059+ coefficients : Optional [List [float ]] = None ,
10341060) -> torch .Tensor :
10351061 if mask .ndim < 3 :
10361062 mask = mask .unsqueeze (0 )
@@ -1039,7 +1065,7 @@ def perspective_mask(
10391065 needs_squeeze = False
10401066
10411067 output = perspective_image_tensor (
1042- mask , perspective_coeffs = perspective_coeffs , interpolation = InterpolationMode .NEAREST , fill = fill
1068+ mask , startpoints , endpoints , interpolation = InterpolationMode .NEAREST , fill = fill , coefficients = coefficients
10431069 )
10441070
10451071 if needs_squeeze :
@@ -1050,25 +1076,37 @@ def perspective_mask(
10501076
10511077def perspective_video (
10521078 video : torch .Tensor ,
1053- perspective_coeffs : List [float ],
1079+ startpoints : Optional [List [List [int ]]],
1080+ endpoints : Optional [List [List [int ]]],
10541081 interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
10551082 fill : features .FillTypeJIT = None ,
1083+ coefficients : Optional [List [float ]] = None ,
10561084) -> torch .Tensor :
1057- return perspective_image_tensor (video , perspective_coeffs , interpolation = interpolation , fill = fill )
1085+ return perspective_image_tensor (
1086+ video , startpoints , endpoints , interpolation = interpolation , fill = fill , coefficients = coefficients
1087+ )
10581088
10591089
10601090def perspective (
10611091 inpt : features .InputTypeJIT ,
1062- perspective_coeffs : List [float ],
1092+ startpoints : Optional [List [List [int ]]],
1093+ endpoints : Optional [List [List [int ]]],
10631094 interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
10641095 fill : features .FillTypeJIT = None ,
1096+ coefficients : Optional [List [float ]] = None ,
10651097) -> features .InputTypeJIT :
10661098 if isinstance (inpt , torch .Tensor ) and (torch .jit .is_scripting () or not isinstance (inpt , features ._Feature )):
1067- return perspective_image_tensor (inpt , perspective_coeffs , interpolation = interpolation , fill = fill )
1099+ return perspective_image_tensor (
1100+ inpt , startpoints , endpoints , interpolation = interpolation , fill = fill , coefficients = coefficients
1101+ )
10681102 elif isinstance (inpt , features ._Feature ):
1069- return inpt .perspective (perspective_coeffs , interpolation = interpolation , fill = fill )
1103+ return inpt .perspective (
1104+ startpoints , endpoints , interpolation = interpolation , fill = fill , coefficients = coefficients
1105+ )
10701106 else :
1071- return perspective_image_pil (inpt , perspective_coeffs , interpolation = interpolation , fill = fill )
1107+ return perspective_image_pil (
1108+ inpt , startpoints , endpoints , interpolation = interpolation , fill = fill , coefficients = coefficients
1109+ )
10721110
10731111
10741112def elastic_image_tensor (
0 commit comments