 import torch.testing
 from datasets_utils import combinations_grid
 from torch.nn.functional import one_hot
-from torch.testing._comparison import (
-    assert_equal as _assert_equal,
-    BooleanPair,
-    ErrorMeta,
-    NonePair,
-    NumberPair,
-    TensorLikePair,
-    UnsupportedInputs,
-)
+from torch.testing._comparison import assert_equal as _assert_equal, BooleanPair, NonePair, NumberPair, TensorLikePair
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
+from torchvision.prototype.transforms.functional import to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

 __all__ = [
     ...
 ]


-class PILImagePair(TensorLikePair):
+class ImagePair(TensorLikePair):
     def __init__(
         self,
         actual,
@@ -64,44 +56,13 @@ def __init__(
         allowed_percentage_diff=None,
         **other_parameters,
     ):
-        if not any(isinstance(input, PIL.Image.Image) for input in (actual, expected)):
-            raise UnsupportedInputs()
-
-        # This parameter is ignored to enable checking PIL images to tensor images not on the CPU
-        other_parameters["check_device"] = False
+        if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
+            actual, expected = [to_image_tensor(input) for input in [actual, expected]]

         super().__init__(actual, expected, **other_parameters)
         self.agg_method = getattr(torch, agg_method) if isinstance(agg_method, str) else agg_method
         self.allowed_percentage_diff = allowed_percentage_diff

-    def _process_inputs(self, actual, expected, *, id, allow_subclasses):
-        actual, expected = [
-            to_image_tensor(input) if not isinstance(input, torch.Tensor) else features.Image(input)
-            for input in [actual, expected]
-        ]
-        # This broadcast is needed, because `features.Mask`'s can have a 2D shape, but converting the equivalent PIL
-        # image to a tensor adds a singleton leading dimension.
-        # Although it looks like this belongs in `self._equalize_attributes`, it has to happen here.
-        # `self._equalize_attributes` is called after `super()._compare_attributes` and that has an unconditional
-        # shape check that will fail if we don't broadcast before.
-        try:
-            actual, expected = torch.broadcast_tensors(actual, expected)
-        except RuntimeError:
-            raise ErrorMeta(
-                AssertionError,
-                f"The image shapes are not broadcastable: {actual.shape}, {expected.shape}.",
-                id=id,
-            ) from None
-        return super()._process_inputs(actual, expected, id=id, allow_subclasses=allow_subclasses)
-
-    def _equalize_attributes(self, actual, expected):
-        if actual.dtype != expected.dtype:
-            dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_dtype_image_tensor(actual, dtype)
-            expected = convert_dtype_image_tensor(expected, dtype)
-
-        return super()._equalize_attributes(actual, expected)
-
     def compare(self) -> None:
         actual, expected = self.actual, self.expected

@@ -111,16 +72,24 @@ def compare(self) -> None:
         abs_diff = torch.abs(actual - expected)

         if self.allowed_percentage_diff is not None:
-            percentage_diff = (abs_diff != 0).to(torch.float).mean()
+            percentage_diff = float(abs_diff.ne(0).to(torch.float64).mean())
             if percentage_diff > self.allowed_percentage_diff:
-                self._make_error_meta(AssertionError, "percentage mismatch")
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"{percentage_diff:.1%} elements differ, "
+                    f"but only {self.allowed_percentage_diff:.1%} is allowed",
+                )

         if self.agg_method is None:
             super()._compare_values(actual, expected)
         else:
-            err = self.agg_method(abs_diff.to(torch.float64))
-            if err > self.atol:
-                self._make_error_meta(AssertionError, "aggregated mismatch")
+            agg_abs_diff = float(self.agg_method(abs_diff.to(torch.float64)))
+            if agg_abs_diff > self.atol:
+                raise self._make_error_meta(
+                    AssertionError,
+                    f"The '{self.agg_method.__name__}' of the absolute difference is {agg_abs_diff}, "
+                    f"but only {self.atol} is allowed",
+                )


 def assert_close(
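
The two tolerance knobs above compose as follows: `allowed_percentage_diff` gates on the fraction of differing elements, while `agg_method` replaces the element-wise value check with a single aggregated comparison against `atol`. A minimal usage sketch, assuming the `assert_close` wrapper below forwards its extra keyword arguments to the pair classes; the tensors and tolerances are invented for illustration:

    import torch

    actual = torch.zeros(100)
    expected = torch.zeros(100)
    expected[:5] = 1.0  # 5% of the elements differ by 1.0

    # Passes: 5% differing elements is within the 10% budget, and the mean absolute
    # difference (0.05) does not exceed atol, so the aggregated check also succeeds.
    assert_close(actual, expected, allowed_percentage_diff=0.10, agg_method="mean", atol=0.1, rtol=0)

    # Fails the percentage gate: 5% of the elements differ, but only 1% is allowed.
    try:
        assert_close(actual, expected, allowed_percentage_diff=0.01, agg_method="mean", atol=0.1, rtol=0)
    except AssertionError as error:
        print(error)
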
@@ -148,7 +117,7 @@ def assert_close(
             NonePair,
             BooleanPair,
             NumberPair,
-            PILImagePair,
+            ImagePair,
             TensorLikePair,
         ),
         allow_subclasses=allow_subclasses,
@@ -167,6 +136,32 @@ def assert_close(
 assert_equal = functools.partial(assert_close, rtol=0, atol=0)


+def parametrized_error_message(*args, **kwargs):
+    def to_str(obj):
+        if isinstance(obj, torch.Tensor) and obj.numel() > 10:
+            return f"tensor(shape={list(obj.shape)}, dtype={obj.dtype}, device={obj.device})"
+        else:
+            return repr(obj)
+
+    if args or kwargs:
+        postfix = "\n".join(
+            [
+                "",
+                "Failure happened for the following parameters:",
+                "",
+                *[to_str(arg) for arg in args],
+                *[f"{name}={to_str(kwarg)}" for name, kwarg in kwargs.items()],
+            ]
+        )
+    else:
+        postfix = ""
+
+    def wrapper(msg):
+        return msg + postfix
+
+    return wrapper
+
+
 class ArgsKwargs:
     def __init__(self, *args, **kwargs):
         self.args = args
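
`parametrized_error_message` produces a callable suitable for the `msg` parameter of `torch.testing.assert_close`, which calls it with the generated failure message and uses the return value instead. A self-contained sketch, assuming the `assert_close` wrapper above forwards `msg`; the parameter values are invented:

    import torch

    actual = torch.rand(3, 32, 32)
    expected = actual + 1  # force a mismatch so the postfix shows up

    try:
        assert_close(
            actual,
            expected,
            # On failure, the parameters below are appended to the mismatch message.
            msg=parametrized_error_message(actual, size=(16, 16), antialias=True),
        )
    except AssertionError as error:
        print(error)  # ends with "Failure happened for the following parameters: ..."
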
@@ -656,6 +651,13 @@ def get_marks(self, test_id, args_kwargs):
         ]

     def get_closeness_kwargs(self, test_id, *, dtype, device):
+        if not (isinstance(test_id, tuple) and len(test_id) == 2):
+            msg = "`test_id` should be a `Tuple[Optional[str], str]` denoting the test class and function name"
+            if callable(test_id):
+                msg += ". Did you forget to add the `test_id` fixture to parameters of the test?"
+            else:
+                msg += f", but got {test_id} instead."
+            raise pytest.UsageError(msg)
         if isinstance(device, torch.device):
             device = device.type
         return self.closeness_kwargs.get((test_id, dtype, device), dict())
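
The lookup above keys `closeness_kwargs` on `(test_id, dtype, device)`, where `test_id` is the `(test class, test function)` pair a `test_id` fixture would provide (the class may be `None` for free functions). A minimal sketch of the keying convention; the test names, dtypes, and tolerances are invented:

    import torch

    closeness_kwargs = {
        (("TestKernels", "test_against_reference"), torch.uint8, "cpu"): dict(atol=1, rtol=0),
        ((None, "test_scripted_vs_eager"), torch.float32, "cpu"): dict(rtol=1e-5, atol=1e-5),
    }

    test_id = ("TestKernels", "test_against_reference")
    assert closeness_kwargs.get((test_id, torch.uint8, "cpu"), dict()) == dict(atol=1, rtol=0)

    # Unknown (test_id, dtype, device) combinations fall back to an empty dict,
    # i.e. the default tolerances of `assert_close`.
    assert closeness_kwargs.get((test_id, torch.float64, "cuda"), dict()) == dict()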