 import unittest
 import warnings
 
+import pytest
+
 
 def get_available_classification_models():
     # TODO add a registration mechanism to torchvision.models
@@ -79,7 +81,7 @@ def _test_classification_model(self, name, input_shape, dev):
         # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
         x = torch.rand(input_shape).to(device=dev)
         out = model(x)
-        self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
+        self.assertExpected(out.cpu(), name, prec=0.1)
         self.assertEqual(out.shape[-1], 50)
         self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
 
@@ -88,7 +90,7 @@ def _test_classification_model(self, name, input_shape, dev):
                 out = model(x)
                 # See autocast_flaky_numerics comment at top of file.
                 if name not in autocast_flaky_numerics:
-                    self.assertExpected(out.cpu(), prec=0.1, strip_suffix=f"_{dev}")
+                    self.assertExpected(out.cpu(), name, prec=0.1)
                 self.assertEqual(out.shape[-1], 50)
 
     def _test_segmentation_model(self, name, dev):
@@ -104,17 +106,16 @@ def _test_segmentation_model(self, name, dev):
 
         def check_out(out):
             prec = 0.01
-            strip_suffix = f"_{dev}"
             try:
                 # We first try to assert the entire output if possible. This is not
                 # only the best way to assert results but also handles the cases
                 # where we need to create a new expected result.
-                self.assertExpected(out.cpu(), prec=prec, strip_suffix=strip_suffix)
+                self.assertExpected(out.cpu(), name, prec=prec)
             except AssertionError:
                 # Unfortunately some segmentation models are flaky with autocast
                 # so instead of validating the probability scores, check that the class
                 # predictions match.
-                expected_file = self._get_expected_file(strip_suffix=strip_suffix)
+                expected_file = self._get_expected_file(name)
                 expected = torch.load(expected_file)
                 self.assertEqual(out.argmax(dim=1), expected.argmax(dim=1), prec=prec)
                 return False  # Partial validation performed
@@ -189,18 +190,17 @@ def compute_mean_std(tensor):
 
         output = map_nested_tensor_object(out, tensor_map_fn=compact)
         prec = 0.01
-        strip_suffix = f"_{dev}"
         try:
             # We first try to assert the entire output if possible. This is not
             # only the best way to assert results but also handles the cases
             # where we need to create a new expected result.
-            self.assertExpected(output, prec=prec, strip_suffix=strip_suffix)
+            self.assertExpected(output, name, prec=prec)
         except AssertionError:
             # Unfortunately detection models are flaky due to the unstable sort
             # in NMS. If matching across all outputs fails, use the same approach
             # as in NMSTester.test_nms_cuda to see if this is caused by duplicate
             # scores.
-            expected_file = self._get_expected_file(strip_suffix=strip_suffix)
+            expected_file = self._get_expected_file(name)
             expected = torch.load(expected_file)
             self.assertEqual(output[0]["scores"], expected[0]["scores"], prec=prec)
 
@@ -430,50 +430,35 @@ def test_generalizedrcnn_transform_repr(self):
 _devs = [torch.device("cpu"), torch.device("cuda")] if torch.cuda.is_available() else [torch.device("cpu")]
 
 
-for model_name in get_available_classification_models():
-    for dev in _devs:
-        # for-loop bodies don't define scopes, so we have to save the variables
-        # we want to close over in some way
-        def do_test(self, model_name=model_name, dev=dev):
-            input_shape = (1, 3, 224, 224)
-            if model_name in ['inception_v3']:
-                input_shape = (1, 3, 299, 299)
-            self._test_classification_model(model_name, input_shape, dev)
-
-        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
-
-
-for model_name in get_available_segmentation_models():
-    for dev in _devs:
-        # for-loop bodies don't define scopes, so we have to save the variables
-        # we want to close over in some way
-        def do_test(self, model_name=model_name, dev=dev):
-            self._test_segmentation_model(model_name, dev)
+@pytest.mark.parametrize('model_name', get_available_classification_models())
+@pytest.mark.parametrize('dev', _devs)
+def test_classification_model(model_name, dev):
+    input_shape = (1, 3, 299, 299) if model_name == 'inception_v3' else (1, 3, 224, 224)
+    ModelTester()._test_classification_model(model_name, input_shape, dev)
 
-        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
 
+@pytest.mark.parametrize('model_name', get_available_segmentation_models())
+@pytest.mark.parametrize('dev', _devs)
+def test_segmentation_model(model_name, dev):
+    ModelTester()._test_segmentation_model(model_name, dev)
 
-for model_name in get_available_detection_models():
-    for dev in _devs:
-        # for-loop bodies don't define scopes, so we have to save the variables
-        # we want to close over in some way
-        def do_test(self, model_name=model_name, dev=dev):
-            self._test_detection_model(model_name, dev)
 
-        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
+@pytest.mark.parametrize('model_name', get_available_detection_models())
+@pytest.mark.parametrize('dev', _devs)
+def test_detection_model(model_name, dev):
+    ModelTester()._test_detection_model(model_name, dev)
 
-        def do_validation_test(self, model_name=model_name):
-            self._test_detection_model_validation(model_name)
 
-        setattr(ModelTester, "test_" + model_name + "_validation", do_validation_test)
+@pytest.mark.parametrize('model_name', get_available_detection_models())
+def test_detection_model_validation(model_name):
+    ModelTester()._test_detection_model_validation(model_name)
 
 
-for model_name in get_available_video_models():
-    for dev in _devs:
-        def do_test(self, model_name=model_name, dev=dev):
-            self._test_video_model(model_name, dev)
+@pytest.mark.parametrize('model_name', get_available_video_models())
+@pytest.mark.parametrize('dev', _devs)
+def test_video_model(model_name, dev):
+    ModelTester()._test_video_model(model_name, dev)
 
-        setattr(ModelTester, f"test_{model_name}_{dev}", do_test)
 
 if __name__ == '__main__':
-    unittest.main()
+    pytest.main([__file__])
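
Note on the `# RNG always on CPU` comment in the first hunk: drawing the input from the CPU generator and then moving it to the device is what lets a single expected-output file serve both CPU and CUDA runs. A minimal standalone sketch of that behavior (not part of the patch):

```python
import torch

torch.manual_seed(0)
a = torch.rand(4)  # drawn from the CPU generator

torch.manual_seed(0)
b = torch.rand(4)  # same generator, same bits
if torch.cuda.is_available():
    b = b.to("cuda")  # .to() copies the values unchanged

# a and b are bitwise identical regardless of the target device.
assert torch.equal(a, b.cpu())

# By contrast, torch.rand(4, device="cuda") would consume the CUDA generator,
# which yields different values for the same seed, so CPU and CUDA runs would
# need separate expected files.
```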
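
The deleted loops pass `model_name=model_name, dev=dev` as default arguments because, as the removed comments say, for-loop bodies don't define scopes and closures bind names late. A standalone sketch of the pitfall those defaults worked around (illustrative names only):

```python
tests = []
for n in ["a", "b"]:
    # Late binding: 'n' is looked up when the lambda is called, after the
    # loop has finished, so every closure sees the final value "b".
    tests.append(lambda: n)
assert [t() for t in tests] == ["b", "b"]

tests = []
for n in ["a", "b"]:
    # The workaround the removed do_test functions used: a default argument
    # is evaluated at definition time, freezing the current loop value.
    tests.append(lambda n=n: n)
assert [t() for t in tests] == ["a", "b"]
```

With `@pytest.mark.parametrize`, each test receives its arguments directly, so both the default-argument trick and the `setattr` registration become unnecessary.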
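
For reference, a minimal standalone sketch of the pattern the patch adopts: stacked `@pytest.mark.parametrize` decorators take the cross product of their argument lists, so one plain function replaces the nested loops. The model names below are placeholders, not the real torchvision registry:

```python
import pytest
import torch

_devs = [torch.device("cpu")]  # mirrors the _devs list in the patch

@pytest.mark.parametrize('model_name', ['resnet18', 'alexnet'])
@pytest.mark.parametrize('dev', _devs)
def test_model(model_name, dev):
    # pytest generates one test per (dev, model_name) pair, with IDs along
    # the lines of test_model[dev0-resnet18] (non-string parameters fall back
    # to argname+index), so `pytest -k resnet18` selects by model much like
    # the old generated test_resnet18_cpu methods did.
    assert isinstance(model_name, str) and isinstance(dev, torch.device)
```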