1+ import contextlib
12import functools
23import io
34import operator
45import os
6+ import pkgutil
7+ import sys
58import traceback
69import warnings
710from collections import OrderedDict
1417from common_utils import map_nested_tensor_object , freeze_rng_state , set_rng_seed , cpu_and_gpu , needs_cuda
1518from torchvision import models
1619
17-
1820ACCEPT = os .getenv ("EXPECTTEST_ACCEPT" , "0" ) == "1"
1921
2022
@@ -23,6 +25,51 @@ def get_models_from_module(module):
2325 return [v for k , v in module .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
2426
2527
28+ @pytest .fixture
29+ def disable_weight_loading (mocker ):
30+ """When testing models, the two slowest operations are the downloading of the weights to a file and loading them
31+ into the model. Unless, you want to test against specific weights, these steps can be disabled without any
32+ drawbacks.
33+
34+ Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
35+ through all models in `torchvision.models` and will patch all occurrences of the function
36+ `download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
37+ no-ops.
38+
39+ .. warning:
40+
41+ Loaded models are still executable as normal, but will always have random weights. Make sure to not use this
42+ fixture if you want to compare the model output against reference values.
43+
44+ """
45+ starting_point = models
46+ function_name = "load_state_dict_from_url"
47+ method_name = "load_state_dict"
48+
49+ module_names = {info .name for info in pkgutil .walk_packages (starting_point .__path__ , f"{ starting_point .__name__ } ." )}
50+ targets = {f"torchvision._internally_replaced_utils.{ function_name } " , f"torch.nn.Module.{ method_name } " }
51+ for name in module_names :
52+ module = sys .modules .get (name )
53+ if not module :
54+ continue
55+
56+ if function_name in module .__dict__ :
57+ targets .add (f"{ module .__name__ } .{ function_name } " )
58+
59+ targets .update (
60+ {
61+ f"{ module .__name__ } .{ obj .__name__ } .{ method_name } "
62+ for obj in module .__dict__ .values ()
63+ if isinstance (obj , type ) and issubclass (obj , nn .Module ) and method_name in obj .__dict__
64+ }
65+ )
66+
67+ for target in targets :
68+ # See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
69+ with contextlib .suppress (AttributeError ):
70+ mocker .patch (target )
71+
72+
2673def _get_expected_file (name = None ):
2774 # Determine expected file based on environment
2875 expected_file_base = get_relative_path (os .path .realpath (__file__ ), "expect" )
@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn):
762809
763810
764811@pytest .mark .parametrize ("model_fn" , get_models_from_module (models .detection ))
765- def test_detection_model_trainable_backbone_layers (model_fn ):
812+ def test_detection_model_trainable_backbone_layers (model_fn , disable_weight_loading ):
766813 model_name = model_fn .__name__
767814 max_trainable = _model_tests_values [model_name ]["max_trainable" ]
768815 n_trainable_params = []
0 commit comments