Skip to content

Commit c69d2f1

Browse files
Yosua Michael Maranathafacebook-github-bot
authored andcommitted
[fbsync] add a signature consistency tests for v1 vs. v2 dispatchers (#6914)
Summary: * add a signature consistency tests for v1 vs. v2 dispatchers * temporarily increase test verbosity * Revert "temporarily increase test verbosity" This reverts commit 468c73f. * fix test to allow annotation deviations * fill <-> center for rotate * ignore annotation changes for center / translate in rotate / affine Reviewed By: NicolasHug Differential Revision: D41265182 fbshipit-source-id: 141c0a89a7579578386eeb159bb9e89e43c8394e
1 parent 18ee274 commit c69d2f1

File tree

8 files changed

+92
-13
lines changed

8 files changed

+92
-13
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
from torchvision import transforms as legacy_transforms
2626
from torchvision._utils import sequence_to_str
2727
from torchvision.prototype import features, transforms as prototype_transforms
28-
from torchvision.prototype.transforms import functional as F
28+
from torchvision.prototype.transforms import functional as prototype_F
2929
from torchvision.prototype.transforms._utils import query_spatial_size
3030
from torchvision.prototype.transforms.functional import to_image_pil
31+
from torchvision.transforms import functional as legacy_F
3132

3233
DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=[features.ColorSpace.RGB], extra_dims=[(4,)])
3334

@@ -985,7 +986,7 @@ def _transform(self, inpt, params):
985986
return inpt
986987

987988
fill = self.fill[type(inpt)]
988-
return F.pad(inpt, padding=params["padding"], fill=fill)
989+
return prototype_F.pad(inpt, padding=params["padding"], fill=fill)
989990

990991

991992
class TestRefSegTransforms:
@@ -1119,3 +1120,81 @@ def test_random_resize_eval(self, mocker):
11191120
t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)
11201121

11211122
self.check_resize(mocker, t_ref, t)
1123+
1124+
1125+
@pytest.mark.parametrize(
1126+
("legacy_dispatcher", "name_only_params"),
1127+
[
1128+
(legacy_F.get_dimensions, {}),
1129+
(legacy_F.get_image_size, {}),
1130+
(legacy_F.get_image_num_channels, {}),
1131+
(legacy_F.to_tensor, {}),
1132+
(legacy_F.pil_to_tensor, {}),
1133+
(legacy_F.convert_image_dtype, {}),
1134+
(legacy_F.to_pil_image, {}),
1135+
(legacy_F.normalize, {}),
1136+
(legacy_F.resize, {}),
1137+
(legacy_F.pad, {"padding", "fill"}),
1138+
(legacy_F.crop, {}),
1139+
(legacy_F.center_crop, {}),
1140+
(legacy_F.resized_crop, {}),
1141+
(legacy_F.hflip, {}),
1142+
(legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
1143+
(legacy_F.vflip, {}),
1144+
(legacy_F.five_crop, {}),
1145+
(legacy_F.ten_crop, {}),
1146+
(legacy_F.adjust_brightness, {}),
1147+
(legacy_F.adjust_contrast, {}),
1148+
(legacy_F.adjust_saturation, {}),
1149+
(legacy_F.adjust_hue, {}),
1150+
(legacy_F.adjust_gamma, {}),
1151+
(legacy_F.rotate, {"center", "fill"}),
1152+
(legacy_F.affine, {"angle", "translate", "center", "fill"}),
1153+
(legacy_F.to_grayscale, {}),
1154+
(legacy_F.rgb_to_grayscale, {}),
1155+
(legacy_F.to_tensor, {}),
1156+
(legacy_F.erase, {}),
1157+
(legacy_F.gaussian_blur, {}),
1158+
(legacy_F.invert, {}),
1159+
(legacy_F.posterize, {}),
1160+
(legacy_F.solarize, {}),
1161+
(legacy_F.adjust_sharpness, {}),
1162+
(legacy_F.autocontrast, {}),
1163+
(legacy_F.equalize, {}),
1164+
(legacy_F.elastic_transform, {"fill"}),
1165+
],
1166+
)
1167+
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
1168+
legacy_signature = inspect.signature(legacy_dispatcher)
1169+
legacy_params = list(legacy_signature.parameters.values())[1:]
1170+
1171+
try:
1172+
prototype_dispatcher = getattr(prototype_F, legacy_dispatcher.__name__)
1173+
except AttributeError:
1174+
raise AssertionError(
1175+
f"Legacy dispatcher `F.{legacy_dispatcher.__name__}` has no prototype equivalent"
1176+
) from None
1177+
1178+
prototype_signature = inspect.signature(prototype_dispatcher)
1179+
prototype_params = list(prototype_signature.parameters.values())[1:]
1180+
1181+
# Some dispatchers got extra parameters. This makes sure they have a default argument and thus are BC. We don't
1182+
# need to check if parameters were added in the middle rather than at the end, since that will be caught by the
1183+
# regular check below.
1184+
prototype_params, new_prototype_params = (
1185+
prototype_params[: len(legacy_params)],
1186+
prototype_params[len(legacy_params) :],
1187+
)
1188+
for param in new_prototype_params:
1189+
assert param.default is not param.empty
1190+
1191+
# Some annotations were changed mostly to supersets of what was there before. Plus, some legacy dispatchers had no
1192+
# annotations. In these cases we simply drop the annotation and default argument from the comparison
1193+
for prototype_param, legacy_param in zip(prototype_params, legacy_params):
1194+
if legacy_param.name in name_only_params:
1195+
prototype_param._annotation = prototype_param._default = inspect.Parameter.empty
1196+
legacy_param._annotation = legacy_param._default = inspect.Parameter.empty
1197+
elif legacy_param.annotation is inspect.Parameter.empty:
1198+
prototype_param._annotation = inspect.Parameter.empty
1199+
1200+
assert prototype_params == legacy_params

torchvision/prototype/features/_bounding_box.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ def rotate(
132132
angle: float,
133133
interpolation: InterpolationMode = InterpolationMode.NEAREST,
134134
expand: bool = False,
135-
fill: FillTypeJIT = None,
136135
center: Optional[List[float]] = None,
136+
fill: FillTypeJIT = None,
137137
) -> BoundingBox:
138138
output, spatial_size = self._F.rotate_bounding_box(
139139
self.as_subclass(torch.Tensor),

torchvision/prototype/features/_feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ def rotate(
199199
angle: float,
200200
interpolation: InterpolationMode = InterpolationMode.NEAREST,
201201
expand: bool = False,
202-
fill: FillTypeJIT = None,
203202
center: Optional[List[float]] = None,
203+
fill: FillTypeJIT = None,
204204
) -> _Feature:
205205
return self
206206

torchvision/prototype/features/_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ def rotate(
174174
angle: float,
175175
interpolation: InterpolationMode = InterpolationMode.NEAREST,
176176
expand: bool = False,
177-
fill: FillTypeJIT = None,
178177
center: Optional[List[float]] = None,
178+
fill: FillTypeJIT = None,
179179
) -> Image:
180180
output = self._F.rotate_image_tensor(
181181
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center

torchvision/prototype/features/_mask.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def rotate(
8989
angle: float,
9090
interpolation: InterpolationMode = InterpolationMode.NEAREST,
9191
expand: bool = False,
92-
fill: FillTypeJIT = None,
9392
center: Optional[List[float]] = None,
93+
fill: FillTypeJIT = None,
9494
) -> Mask:
9595
output = self._F.rotate_mask(self.as_subclass(torch.Tensor), angle, expand=expand, center=center, fill=fill)
9696
return Mask.wrap_like(self, output)

torchvision/prototype/features/_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def rotate(
134134
angle: float,
135135
interpolation: InterpolationMode = InterpolationMode.NEAREST,
136136
expand: bool = False,
137-
fill: FillTypeJIT = None,
138137
center: Optional[List[float]] = None,
138+
fill: FillTypeJIT = None,
139139
) -> Video:
140140
output = self._F.rotate_video(
141141
self.as_subclass(torch.Tensor), angle, interpolation=interpolation, expand=expand, fill=fill, center=center

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
305305
**params,
306306
interpolation=self.interpolation,
307307
expand=self.expand,
308-
fill=fill,
309308
center=self.center,
309+
fill=fill,
310310
)
311311

312312

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -521,8 +521,8 @@ def rotate_image_tensor(
521521
angle: float,
522522
interpolation: InterpolationMode = InterpolationMode.NEAREST,
523523
expand: bool = False,
524-
fill: features.FillTypeJIT = None,
525524
center: Optional[List[float]] = None,
525+
fill: features.FillTypeJIT = None,
526526
) -> torch.Tensor:
527527
shape = image.shape
528528
num_channels, height, width = shape[-3:]
@@ -560,8 +560,8 @@ def rotate_image_pil(
560560
angle: float,
561561
interpolation: InterpolationMode = InterpolationMode.NEAREST,
562562
expand: bool = False,
563-
fill: features.FillTypeJIT = None,
564563
center: Optional[List[float]] = None,
564+
fill: features.FillTypeJIT = None,
565565
) -> PIL.Image.Image:
566566
if center is not None and expand:
567567
warnings.warn("The provided center argument has no effect on the result if expand is True")
@@ -612,8 +612,8 @@ def rotate_mask(
612612
mask: torch.Tensor,
613613
angle: float,
614614
expand: bool = False,
615-
fill: features.FillTypeJIT = None,
616615
center: Optional[List[float]] = None,
616+
fill: features.FillTypeJIT = None,
617617
) -> torch.Tensor:
618618
if mask.ndim < 3:
619619
mask = mask.unsqueeze(0)
@@ -641,8 +641,8 @@ def rotate_video(
641641
angle: float,
642642
interpolation: InterpolationMode = InterpolationMode.NEAREST,
643643
expand: bool = False,
644-
fill: features.FillTypeJIT = None,
645644
center: Optional[List[float]] = None,
645+
fill: features.FillTypeJIT = None,
646646
) -> torch.Tensor:
647647
return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
648648

@@ -652,8 +652,8 @@ def rotate(
652652
angle: float,
653653
interpolation: InterpolationMode = InterpolationMode.NEAREST,
654654
expand: bool = False,
655-
fill: features.FillTypeJIT = None,
656655
center: Optional[List[float]] = None,
656+
fill: features.FillTypeJIT = None,
657657
) -> features.InputTypeJIT:
658658
if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
659659
return rotate_image_tensor(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)

0 commit comments

Comments
 (0)