
Commit 10a7111

eellison authored and fmassa committed
remove BC-breaking changes (#1560)
* remove changes that induced BC
* Re-enable tests that have been disabled
* Remove outdated comment
* Remove outdated comment
1 parent b05fc26 commit 10a7111

File tree:
* torchvision/models/_utils.py
* torchvision/models/densenet.py

2 files changed: +6 −47 lines

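Why the reverted change was BC-breaking: registering children under a nested attribute (`self.layers = nn.ModuleDict(...)`) prepends `layers.` to every `state_dict` key, so checkpoints serialized before the change no longer match, which is what the removed `_load_from_state_dict` shims below worked around. A minimal sketch of the key difference (the `Nested`/`Flat` class names are hypothetical, not torchvision code):

```python
import torch.nn as nn


class Nested(nn.Module):      # the reverted design: extra 'layers.' nesting level
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({'conv': nn.Conv2d(3, 8, 1)})


class Flat(nn.ModuleDict):    # the restored design: children registered at top level
    def __init__(self):
        super().__init__({'conv': nn.Conv2d(3, 8, 1)})


print(list(Nested().state_dict()))  # ['layers.conv.weight', 'layers.conv.bias']
print(list(Flat().state_dict()))    # ['conv.weight', 'conv.bias']
```

Subclassing `nn.ModuleDict` keeps the old flat key names, so existing pretrained checkpoints load without any remapping.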

torchvision/models/_utils.py

Lines changed: 3 additions & 24 deletions
```diff
@@ -5,7 +5,7 @@
 from torch.jit.annotations import Dict
 
 
-class IntermediateLayerGetter(nn.Module):
+class IntermediateLayerGetter(nn.ModuleDict):
     """
     Module wrapper that returns intermediate layers from a model
 
@@ -45,8 +45,6 @@ class IntermediateLayerGetter(nn.Module):
     def __init__(self, model, return_layers):
         if not set(return_layers).issubset([name for name, _ in model.named_children()]):
             raise ValueError("return_layers are not present in model")
-        super(IntermediateLayerGetter, self).__init__()
-
         orig_return_layers = return_layers
         return_layers = {k: v for k, v in return_layers.items()}
         layers = OrderedDict()
@@ -57,33 +55,14 @@ def __init__(self, model, return_layers):
             if not return_layers:
                 break
 
-        self.layers = nn.ModuleDict(layers)
+        super(IntermediateLayerGetter, self).__init__(layers)
         self.return_layers = orig_return_layers
 
     def forward(self, x):
         out = OrderedDict()
-        for name, module in self.layers.items():
+        for name, module in self.items():
             x = module(x)
             if name in self.return_layers:
                 out_name = self.return_layers[name]
                 out[out_name] = x
         return out
-
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(IntermediateLayerGetter, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
```
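For reference, a usage sketch of the restored class (note `IntermediateLayerGetter` lives in the private `_utils` module; the shapes in the final comment assume `resnet18` with a 224×224 input):

```python
import torch
from torchvision.models import resnet18
from torchvision.models._utils import IntermediateLayerGetter

# Keys of return_layers are child-module names of the backbone;
# values are the names used for the corresponding outputs.
body = IntermediateLayerGetter(resnet18(),
                               return_layers={'layer1': 'feat1', 'layer4': 'feat4'})
outs = body(torch.randn(1, 3, 224, 224))
print([(k, v.shape) for k, v in outs.items()])
# [('feat1', torch.Size([1, 64, 56, 56])), ('feat4', torch.Size([1, 512, 7, 7]))]
```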

torchvision/models/densenet.py

Lines changed: 3 additions & 23 deletions
```diff
@@ -90,13 +90,12 @@ def forward(self, input):  # noqa: F811
         return new_features
 
 
-class _DenseBlock(nn.Module):
+class _DenseBlock(nn.ModuleDict):
     _version = 2
     __constants__ = ['layers']
 
     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
         super(_DenseBlock, self).__init__()
-        self.layers = nn.ModuleDict()
         for i in range(num_layers):
             layer = _DenseLayer(
                 num_input_features + i * growth_rate,
@@ -105,34 +104,15 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_ra
                 drop_rate=drop_rate,
                 memory_efficient=memory_efficient,
             )
-            self.layers['denselayer%d' % (i + 1)] = layer
+            self.add_module('denselayer%d' % (i + 1), layer)
 
     def forward(self, init_features):
         features = [init_features]
-        for name, layer in self.layers.items():
+        for name, layer in self.items():
             new_features = layer(features)
             features.append(new_features)
         return torch.cat(features, 1)
 
-    @torch.jit.ignore
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
-        version = local_metadata.get('version', None)
-        if (version is None or version < 2):
-            # now we have a new nesting level for torchscript support
-            for new_key in self.state_dict().keys():
-                # remove prefix "layers."
-                old_key = new_key[len("layers."):]
-                old_key = prefix + old_key
-                new_key = prefix + new_key
-                if old_key in state_dict:
-                    value = state_dict[old_key]
-                    del state_dict[old_key]
-                    state_dict[new_key] = value
-        super(_DenseBlock, self)._load_from_state_dict(
-            state_dict, prefix, local_metadata, strict,
-            missing_keys, unexpected_keys, error_msgs)
-
 
 class _Transition(nn.Sequential):
     def __init__(self, num_input_features, num_output_features):
```
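The same pattern in isolation: a self-contained sketch (the `Block` class is hypothetical, not the real `_DenseBlock`, and concatenates inputs itself rather than passing a feature list) of an `nn.ModuleDict` subclass that registers children via `add_module` and iterates them with `self.items()`:

```python
import torch
import torch.nn as nn


class Block(nn.ModuleDict):
    def __init__(self, num_layers, channels):
        super().__init__()
        for i in range(num_layers):
            # add_module registers the child directly on the block, so its
            # state_dict keys carry no extra nesting, e.g. 'layer1.weight'
            self.add_module('layer%d' % (i + 1),
                            nn.Conv2d(channels * (i + 1), channels, kernel_size=1))

    def forward(self, x):
        features = [x]
        for name, layer in self.items():  # children iterate in insertion order
            features.append(layer(torch.cat(features, 1)))
        return torch.cat(features, 1)


out = Block(2, 4)(torch.randn(1, 4, 8, 8))
print(out.shape)  # torch.Size([1, 12, 8, 8])
```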
