3434DEFAULT_MAKE_IMAGES_KWARGS = dict (color_spaces = ["RGB" ], extra_dims = [(4 ,)])
3535
3636
37+ class NotScriptableArgsKwargs (ArgsKwargs ):
38+ """
39+ This class is used to mark parameters that render the transform non-scriptable. They still work in eager mode and
40+ thus will be tested there, but will be skipped by the JIT tests.
41+ """
42+
43+ pass
44+
45+
3746class ConsistencyConfig :
3847 def __init__ (
3948 self ,
@@ -73,7 +82,7 @@ def __init__(
7382 prototype_transforms .Resize ,
7483 legacy_transforms .Resize ,
7584 [
76- ArgsKwargs (32 ),
85+ NotScriptableArgsKwargs (32 ),
7786 ArgsKwargs ([32 ]),
7887 ArgsKwargs ((32 , 29 )),
7988 ArgsKwargs ((31 , 28 ), interpolation = prototype_transforms .InterpolationMode .NEAREST ),
@@ -84,8 +93,10 @@ def __init__(
8493 # ArgsKwargs((30, 27), interpolation=0),
8594 # ArgsKwargs((35, 29), interpolation=2),
8695 # ArgsKwargs((34, 25), interpolation=3),
87- ArgsKwargs (31 , max_size = 32 ),
88- ArgsKwargs (30 , max_size = 100 ),
96+ NotScriptableArgsKwargs (31 , max_size = 32 ),
97+ ArgsKwargs ([31 ], max_size = 32 ),
98+ NotScriptableArgsKwargs (30 , max_size = 100 ),
99+ ArgsKwargs ([31 ], max_size = 32 ),
89100 ArgsKwargs ((29 , 32 ), antialias = False ),
90101 ArgsKwargs ((28 , 31 ), antialias = True ),
91102 ],
@@ -121,14 +132,15 @@ def __init__(
121132 prototype_transforms .Pad ,
122133 legacy_transforms .Pad ,
123134 [
124- ArgsKwargs (3 ),
135+ NotScriptableArgsKwargs (3 ),
125136 ArgsKwargs ([3 ]),
126137 ArgsKwargs ([2 , 3 ]),
127138 ArgsKwargs ([3 , 2 , 1 , 4 ]),
128- ArgsKwargs (5 , fill = 1 , padding_mode = "constant" ),
129- ArgsKwargs (5 , padding_mode = "edge" ),
130- ArgsKwargs (5 , padding_mode = "reflect" ),
131- ArgsKwargs (5 , padding_mode = "symmetric" ),
139+ NotScriptableArgsKwargs (5 , fill = 1 , padding_mode = "constant" ),
140+ ArgsKwargs ([5 ], fill = 1 , padding_mode = "constant" ),
141+ NotScriptableArgsKwargs (5 , padding_mode = "edge" ),
142+ NotScriptableArgsKwargs (5 , padding_mode = "reflect" ),
143+ NotScriptableArgsKwargs (5 , padding_mode = "symmetric" ),
132144 ],
133145 ),
134146 ConsistencyConfig (
@@ -170,7 +182,7 @@ def __init__(
170182 ConsistencyConfig (
171183 prototype_transforms .ToPILImage ,
172184 legacy_transforms .ToPILImage ,
173- [ArgsKwargs ()],
185+ [NotScriptableArgsKwargs ()],
174186 make_images_kwargs = dict (
175187 color_spaces = [
176188 "GRAY" ,
@@ -186,7 +198,7 @@ def __init__(
186198 prototype_transforms .Lambda ,
187199 legacy_transforms .Lambda ,
188200 [
189- ArgsKwargs (lambda image : image / 2 ),
201+ NotScriptableArgsKwargs (lambda image : image / 2 ),
190202 ],
191203 # Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
192204 # images given that the transform does nothing but call it anyway.
@@ -380,14 +392,15 @@ def __init__(
380392 [
381393 ArgsKwargs (12 ),
382394 ArgsKwargs ((15 , 17 )),
383- ArgsKwargs (11 , padding = 1 ),
395+ NotScriptableArgsKwargs (11 , padding = 1 ),
396+ ArgsKwargs (11 , padding = [1 ]),
384397 ArgsKwargs ((8 , 13 ), padding = (2 , 3 )),
385398 ArgsKwargs ((14 , 9 ), padding = (0 , 2 , 1 , 0 )),
386399 ArgsKwargs (36 , pad_if_needed = True ),
387400 ArgsKwargs ((7 , 8 ), fill = 1 ),
388- ArgsKwargs (5 , fill = (1 , 2 , 3 )),
401+ NotScriptableArgsKwargs (5 , fill = (1 , 2 , 3 )),
389402 ArgsKwargs (12 ),
390- ArgsKwargs (15 , padding = 2 , padding_mode = "edge" ),
403+ NotScriptableArgsKwargs (15 , padding = 2 , padding_mode = "edge" ),
391404 ArgsKwargs (17 , padding = (1 , 0 ), padding_mode = "reflect" ),
392405 ArgsKwargs (8 , padding = (3 , 0 , 0 , 1 ), padding_mode = "symmetric" ),
393406 ],
@@ -642,6 +655,38 @@ def test_call_consistency(config, args_kwargs):
642655 )
643656
644657
658+ @pytest .mark .parametrize (
659+ ("config" , "args_kwargs" ),
660+ [
661+ pytest .param (
662+ config , args_kwargs , id = f"{ config .legacy_cls .__name__ } -{ idx :0{len (str (len (config .args_kwargs )))}d} "
663+ )
664+ for config in CONSISTENCY_CONFIGS
665+ for idx , args_kwargs in enumerate (config .args_kwargs )
666+ if not isinstance (args_kwargs , NotScriptableArgsKwargs )
667+ ],
668+ )
669+ def test_jit_consistency (config , args_kwargs ):
670+ args , kwargs = args_kwargs
671+
672+ prototype_transform_eager = config .prototype_cls (* args , ** kwargs )
673+ legacy_transform_eager = config .legacy_cls (* args , ** kwargs )
674+
675+ legacy_transform_scripted = torch .jit .script (legacy_transform_eager )
676+ prototype_transform_scripted = torch .jit .script (prototype_transform_eager )
677+
678+ for image in make_images (** config .make_images_kwargs ):
679+ image = image .as_subclass (torch .Tensor )
680+
681+ torch .manual_seed (0 )
682+ output_legacy_scripted = legacy_transform_scripted (image )
683+
684+ torch .manual_seed (0 )
685+ output_prototype_scripted = prototype_transform_scripted (image )
686+
687+ assert_close (output_prototype_scripted , output_legacy_scripted , ** config .closeness_kwargs )
688+
689+
645690class TestContainerTransforms :
646691 """
647692 Since we are testing containers here, we also need some transforms to wrap. Thus, testing a container transform for
0 commit comments