From a611445ee4e1a62ea6c30c3b5f08ce7fec4e6047 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 14 Oct 2022 15:23:59 +0200 Subject: [PATCH] allow tolerances in transforms consistency checks --- test/test_prototype_transforms_consistency.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 7f439fb2608..7d2f1d735ea 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -12,6 +12,7 @@ import torch from prototype_common_utils import ( ArgsKwargs, + assert_close, assert_equal, make_bounding_box, make_detection_mask, @@ -40,6 +41,7 @@ def __init__( make_images_kwargs=None, supports_pil=True, removed_params=(), + closeness_kwargs=None, ): self.prototype_cls = prototype_cls self.legacy_cls = legacy_cls @@ -47,6 +49,7 @@ def __init__( self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS self.supports_pil = supports_pil self.removed_params = removed_params + self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0) # These are here since both the prototype and legacy transform need to be constructed with the same random parameters @@ -491,10 +494,14 @@ def test_signature_consistency(config): assert prototype_kinds == legacy_kinds -def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True): +def check_call_consistency( + prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None +): if images is None: images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS) + closeness_kwargs = closeness_kwargs or dict() + for image in images: image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]" @@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s f"`is_simple_tensor` path in `_transform`." ) from exc - assert_equal( + assert_close( output_prototype_tensor, output_legacy_tensor, msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}", + **closeness_kwargs, ) try: @@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s f"`features.Image` path in `_transform`." ) from exc - assert_equal( + assert_close( output_prototype_image, output_prototype_tensor, msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}", + **closeness_kwargs, ) if image.ndim == 3 and supports_pil: @@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s f"`PIL.Image.Image` path in `_transform`." ) from exc - assert_equal( + assert_close( output_prototype_pil, output_legacy_pil, msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}", + **closeness_kwargs, ) @@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs): legacy_transform, images=make_images(**config.make_images_kwargs), supports_pil=config.supports_pil, + closeness_kwargs=config.closeness_kwargs, )