Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
from prototype_common_utils import (
ArgsKwargs,
assert_close,
assert_equal,
make_bounding_box,
make_detection_mask,
Expand Down Expand Up @@ -40,13 +41,15 @@ def __init__(
make_images_kwargs=None,
supports_pil=True,
removed_params=(),
closeness_kwargs=None,
):
self.prototype_cls = prototype_cls
self.legacy_cls = legacy_cls
self.args_kwargs = args_kwargs
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
self.supports_pil = supports_pil
self.removed_params = removed_params
self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)


# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
Expand Down Expand Up @@ -491,10 +494,14 @@ def test_signature_consistency(config):
assert prototype_kinds == legacy_kinds


def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
def check_call_consistency(
prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
):
if images is None:
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)

closeness_kwargs = closeness_kwargs or dict()

for image in images:
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"

Expand All @@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`is_simple_tensor` path in `_transform`."
) from exc

assert_equal(
assert_close(
output_prototype_tensor,
output_legacy_tensor,
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
**closeness_kwargs,
)

try:
Expand All @@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`features.Image` path in `_transform`."
) from exc

assert_equal(
assert_close(
output_prototype_image,
output_prototype_tensor,
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
**closeness_kwargs,
)

if image.ndim == 3 and supports_pil:
Expand All @@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
f"`PIL.Image.Image` path in `_transform`."
) from exc

assert_equal(
assert_close(
output_prototype_pil,
output_legacy_pil,
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
**closeness_kwargs,
)


Expand Down Expand Up @@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs):
legacy_transform,
images=make_images(**config.make_images_kwargs),
supports_pil=config.supports_pil,
closeness_kwargs=config.closeness_kwargs,
)


Expand Down