1212import torch
1313from prototype_common_utils import (
1414 ArgsKwargs ,
15+ assert_close ,
1516 assert_equal ,
1617 make_bounding_box ,
1718 make_detection_mask ,
@@ -40,13 +41,15 @@ def __init__(
4041 make_images_kwargs = None ,
4142 supports_pil = True ,
4243 removed_params = (),
44+ closeness_kwargs = None ,
4345 ):
4446 self .prototype_cls = prototype_cls
4547 self .legacy_cls = legacy_cls
4648 self .args_kwargs = args_kwargs
4749 self .make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
4850 self .supports_pil = supports_pil
4951 self .removed_params = removed_params
52+ self .closeness_kwargs = closeness_kwargs or dict (rtol = 0 , atol = 0 )
5053
5154
5255# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
@@ -491,10 +494,14 @@ def test_signature_consistency(config):
491494 assert prototype_kinds == legacy_kinds
492495
493496
494- def check_call_consistency (prototype_transform , legacy_transform , images = None , supports_pil = True ):
497+ def check_call_consistency (
498+ prototype_transform , legacy_transform , images = None , supports_pil = True , closeness_kwargs = None
499+ ):
495500 if images is None :
496501 images = make_images (** DEFAULT_MAKE_IMAGES_KWARGS )
497502
503+ closeness_kwargs = closeness_kwargs or dict ()
504+
498505 for image in images :
499506 image_repr = f"[{ tuple (image .shape )} , { str (image .dtype ).rsplit ('.' )[- 1 ]} ]"
500507
@@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
520527 f"`is_simple_tensor` path in `_transform`."
521528 ) from exc
522529
523- assert_equal (
530+ assert_close (
524531 output_prototype_tensor ,
525532 output_legacy_tensor ,
526533 msg = lambda msg : f"Tensor image consistency check failed with: \n \n { msg } " ,
534+ ** closeness_kwargs ,
527535 )
528536
529537 try :
@@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
536544 f"`features.Image` path in `_transform`."
537545 ) from exc
538546
539- assert_equal (
547+ assert_close (
540548 output_prototype_image ,
541549 output_prototype_tensor ,
542550 msg = lambda msg : f"Output for feature and tensor images is not equal: \n \n { msg } " ,
551+ ** closeness_kwargs ,
543552 )
544553
545554 if image .ndim == 3 and supports_pil :
@@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
565574 f"`PIL.Image.Image` path in `_transform`."
566575 ) from exc
567576
568- assert_equal (
577+ assert_close (
569578 output_prototype_pil ,
570579 output_legacy_pil ,
571580 msg = lambda msg : f"PIL image consistency check failed with: \n \n { msg } " ,
581+ ** closeness_kwargs ,
572582 )
573583
574584
@@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs):
606616 legacy_transform ,
607617 images = make_images (** config .make_images_kwargs ),
608618 supports_pil = config .supports_pil ,
619+ closeness_kwargs = config .closeness_kwargs ,
609620 )
610621
611622
0 commit comments