@@ -90,13 +90,12 @@ def forward(self, input): # noqa: F811
9090 return new_features
9191
9292
93- class _DenseBlock (nn .Module ):
93+ class _DenseBlock (nn .ModuleDict ):
9494 _version = 2
9595 __constants__ = ['layers' ]
9696
9797 def __init__ (self , num_layers , num_input_features , bn_size , growth_rate , drop_rate , memory_efficient = False ):
9898 super (_DenseBlock , self ).__init__ ()
99- self .layers = nn .ModuleDict ()
10099 for i in range (num_layers ):
101100 layer = _DenseLayer (
102101 num_input_features + i * growth_rate ,
@@ -105,34 +104,15 @@ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_ra
105104 drop_rate = drop_rate ,
106105 memory_efficient = memory_efficient ,
107106 )
108- self .layers [ 'denselayer%d' % (i + 1 )] = layer
107+ self .add_module ( 'denselayer%d' % (i + 1 ), layer )
109108
110109 def forward (self , init_features ):
111110 features = [init_features ]
112- for name , layer in self .layers . items ():
111+ for name , layer in self .items ():
113112 new_features = layer (features )
114113 features .append (new_features )
115114 return torch .cat (features , 1 )
116115
117- @torch .jit .ignore
118- def _load_from_state_dict (self , state_dict , prefix , local_metadata , strict ,
119- missing_keys , unexpected_keys , error_msgs ):
120- version = local_metadata .get ('version' , None )
121- if (version is None or version < 2 ):
122- # now we have a new nesting level for torchscript support
123- for new_key in self .state_dict ().keys ():
124- # remove prefix "layers."
125- old_key = new_key [len ("layers." ):]
126- old_key = prefix + old_key
127- new_key = prefix + new_key
128- if old_key in state_dict :
129- value = state_dict [old_key ]
130- del state_dict [old_key ]
131- state_dict [new_key ] = value
132- super (_DenseBlock , self )._load_from_state_dict (
133- state_dict , prefix , local_metadata , strict ,
134- missing_keys , unexpected_keys , error_msgs )
135-
136116
137117class _Transition (nn .Sequential ):
138118 def __init__ (self , num_input_features , num_output_features ):
0 commit comments