Skip to content

Commit 8ec7a70

Browse files
authored
allow tolerances in transforms consistency checks (#6774)
1 parent c960273 commit 8ec7a70

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

test/test_prototype_transforms_consistency.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from prototype_common_utils import (
1414
ArgsKwargs,
15+
assert_close,
1516
assert_equal,
1617
make_bounding_box,
1718
make_detection_mask,
@@ -40,13 +41,15 @@ def __init__(
4041
make_images_kwargs=None,
4142
supports_pil=True,
4243
removed_params=(),
44+
closeness_kwargs=None,
4345
):
4446
self.prototype_cls = prototype_cls
4547
self.legacy_cls = legacy_cls
4648
self.args_kwargs = args_kwargs
4749
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
4850
self.supports_pil = supports_pil
4951
self.removed_params = removed_params
52+
self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
5053

5154

5255
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
@@ -491,10 +494,14 @@ def test_signature_consistency(config):
491494
assert prototype_kinds == legacy_kinds
492495

493496

494-
def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
497+
def check_call_consistency(
498+
prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
499+
):
495500
if images is None:
496501
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
497502

503+
closeness_kwargs = closeness_kwargs or dict()
504+
498505
for image in images:
499506
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
500507

@@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
520527
f"`is_simple_tensor` path in `_transform`."
521528
) from exc
522529

523-
assert_equal(
530+
assert_close(
524531
output_prototype_tensor,
525532
output_legacy_tensor,
526533
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
534+
**closeness_kwargs,
527535
)
528536

529537
try:
@@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
536544
f"`features.Image` path in `_transform`."
537545
) from exc
538546

539-
assert_equal(
547+
assert_close(
540548
output_prototype_image,
541549
output_prototype_tensor,
542550
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
551+
**closeness_kwargs,
543552
)
544553

545554
if image.ndim == 3 and supports_pil:
@@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
565574
f"`PIL.Image.Image` path in `_transform`."
566575
) from exc
567576

568-
assert_equal(
577+
assert_close(
569578
output_prototype_pil,
570579
output_legacy_pil,
571580
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
581+
**closeness_kwargs,
572582
)
573583

574584

@@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs):
606616
legacy_transform,
607617
images=make_images(**config.make_images_kwargs),
608618
supports_pil=config.supports_pil,
619+
closeness_kwargs=config.closeness_kwargs,
609620
)
610621

611622

0 commit comments

Comments
 (0)