1414import torch .testing
1515from datasets_utils import combinations_grid
1616from torch .nn .functional import one_hot
17- from torch .testing ._comparison import assert_equal as _assert_equal , BooleanPair , NonePair , NumberPair , TensorLikePair
17+ from torch .testing ._comparison import BooleanPair , NonePair , not_close_error_metas , NumberPair , TensorLikePair
1818from torchvision .prototype import datapoints
1919from torchvision .prototype .transforms .functional import convert_dtype_image_tensor , to_image_tensor
2020from torchvision .transforms .functional_tensor import _max_value as get_max_value
@@ -73,7 +73,7 @@ def compare(self) -> None:
7373 actual , expected = self ._promote_for_comparison (actual , expected )
7474 mae = float (torch .abs (actual - expected ).float ().mean ())
7575 if mae > self .atol :
76- raise self ._make_error_meta (
76+ self ._fail (
7777 AssertionError ,
7878 f"The MAE of the images is { mae } , but only { self .atol } is allowed." ,
7979 )
@@ -99,7 +99,7 @@ def assert_close(
9999 """Superset of :func:`torch.testing.assert_close` with support for PIL vs. tensor image comparison"""
100100 __tracebackhide__ = True
101101
102- _assert_equal (
102+ error_metas = not_close_error_metas (
103103 actual ,
104104 expected ,
105105 pair_types = (
@@ -117,10 +117,12 @@ def assert_close(
117117 check_dtype = check_dtype ,
118118 check_layout = check_layout ,
119119 check_stride = check_stride ,
120- msg = msg ,
121120 ** kwargs ,
122121 )
123122
123+ if error_metas :
124+ raise error_metas [0 ].to_error (msg )
125+
124126
125127assert_equal = functools .partial (assert_close , rtol = 0 , atol = 0 )
126128
0 commit comments