@@ -228,23 +228,6 @@ def __init__(
228228 # Use default tolerances of `torch.testing.assert_close`
229229 closeness_kwargs = dict (rtol = None , atol = None ),
230230 ),
231- ConsistencyConfig (
232- v2_transforms .ColorJitter ,
233- legacy_transforms .ColorJitter ,
234- [
235- ArgsKwargs (),
236- ArgsKwargs (brightness = 0.1 ),
237- ArgsKwargs (brightness = (0.2 , 0.3 )),
238- ArgsKwargs (contrast = 0.4 ),
239- ArgsKwargs (contrast = (0.5 , 0.6 )),
240- ArgsKwargs (saturation = 0.7 ),
241- ArgsKwargs (saturation = (0.8 , 0.9 )),
242- ArgsKwargs (hue = 0.3 ),
243- ArgsKwargs (hue = (- 0.1 , 0.2 )),
244- ArgsKwargs (brightness = 0.1 , contrast = 0.4 , saturation = 0.5 , hue = 0.3 ),
245- ],
246- closeness_kwargs = {"atol" : 1e-5 , "rtol" : 1e-5 },
247- ),
248231 ConsistencyConfig (
249232 v2_transforms .PILToTensor ,
250233 legacy_transforms .PILToTensor ,
@@ -453,49 +436,6 @@ def test_call_consistency(config, args_kwargs):
453436 )
454437
455438
456- get_params_parametrization = pytest .mark .parametrize (
457- ("config" , "get_params_args_kwargs" ),
458- [
459- pytest .param (
460- next (config for config in CONSISTENCY_CONFIGS if config .prototype_cls is transform_cls ),
461- get_params_args_kwargs ,
462- id = transform_cls .__name__ ,
463- )
464- for transform_cls , get_params_args_kwargs in [
465- (v2_transforms .ColorJitter , ArgsKwargs (brightness = None , contrast = None , saturation = None , hue = None )),
466- (v2_transforms .AutoAugment , ArgsKwargs (5 )),
467- ]
468- ],
469- )
470-
471-
472- @get_params_parametrization
473- def test_get_params_alias (config , get_params_args_kwargs ):
474- assert config .prototype_cls .get_params is config .legacy_cls .get_params
475-
476- if not config .args_kwargs :
477- return
478- args , kwargs = config .args_kwargs [0 ]
479- legacy_transform = config .legacy_cls (* args , ** kwargs )
480- prototype_transform = config .prototype_cls (* args , ** kwargs )
481-
482- assert prototype_transform .get_params is legacy_transform .get_params
483-
484-
485- @get_params_parametrization
486- def test_get_params_jit (config , get_params_args_kwargs ):
487- get_params_args , get_params_kwargs = get_params_args_kwargs
488-
489- torch .jit .script (config .prototype_cls .get_params )(* get_params_args , ** get_params_kwargs )
490-
491- if not config .args_kwargs :
492- return
493- args , kwargs = config .args_kwargs [0 ]
494- transform = config .prototype_cls (* args , ** kwargs )
495-
496- torch .jit .script (transform .get_params )(* get_params_args , ** get_params_kwargs )
497-
498-
499439@pytest .mark .parametrize (
500440 ("config" , "args_kwargs" ),
501441 [
0 commit comments