@@ -307,7 +307,11 @@ def configure_optimizers(self):
     trainer.fit(model)


-def test_deep_nested_model():
+def test_complex_nested_model():
+    """
+    Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules
+    that directly own parameters, rather than exclusively their submodules containing parameters.
+    """

     class ConvBlock(nn.Module):

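The distinction the new docstring draws — parameters owned directly by a parent module versus parameters owned only by its submodules — maps onto stock PyTorch as follows. This is a minimal standalone sketch; the `Parent` module below is hypothetical and not part of this change, while `parameters(recurse=False)` is the standard `nn.Module` API for enumerating only directly-owned parameters:

import torch
from torch import nn

class Parent(nn.Module):
    """Hypothetical module owning a parameter directly while also having a child module."""

    def __init__(self):
        super().__init__()
        self.child = nn.Linear(4, 4)                # parameters held by a submodule
        self.direct = nn.Parameter(torch.zeros(1))  # parameter held by the parent itself

p = Parent()
# parameters(recurse=False) yields only the directly-owned parameters
assert len(list(p.parameters(recurse=False))) == 1  # just `direct`
assert len(list(p.parameters())) == 3               # `direct` + child weight and bias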
@@ -322,23 +326,39 @@ def forward(self, x):
         x = self.act(x)
         return self.bn(x)

+    class ConvBlockParam(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
     model = nn.Sequential(
         OrderedDict([
-            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
             ("decoder", ConvBlock(128, 10)),
         ])
     )

-    # There's 9 leaf layers in that model
-    assert len(BaseFinetuning.flatten_modules(model)) == 9
+    # There are 10 leaf modules or parent modules w/ parameters in the test model
+    assert len(BaseFinetuning.flatten_modules(model)) == 10

     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad
+    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad

     BaseFinetuning.make_trainable(model)
     encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
-    # The 8 parameters of the encoder are:
-    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
+    # The 9 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
     # conv1.weight, conv1.bias, bn1.weight, bn1.bias
-    assert len(encoder_params) == 8
+    assert len(encoder_params) == 9
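The 9 → 10 change in the flatten_modules assertion follows from the counting rule this change encodes: each of the three blocks contributes three leaf modules (conv, act, bn), and ConvBlockParam is additionally kept because it owns parent_param directly. The actual flatten_modules implementation is not shown in this diff; a minimal sketch of a rule consistent with the assertion (keep leaf modules, plus any parent module with direct parameters) might look like:

from torch import nn

def flatten_with_parent_params(module: nn.Module):
    """Sketch only, not the BaseFinetuning source: keep leaf modules plus any
    parent module that owns parameters directly."""
    return [
        m for m in module.modules()
        if not list(m.children()) or list(m.parameters(recurse=False))
    ]

# For the test model above: 3 blocks x 3 leaf layers (conv, act, bn) = 9 leaves,
# plus the ConvBlockParam parent itself (it owns parent_param) = 10 modules.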
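Likewise, the new freeze assertions are consistent with a freeze that walks the flattened modules and disables requires_grad on each module's directly-owned parameters, skipping BatchNorm when train_bn=True; operating on direct parameters only is what makes parent_param reachable without revisiting the children, which appear in the flattened list themselves. A hedged sketch of that rule (again, not the actual BaseFinetuning.freeze source):

from torch.nn.modules.batchnorm import _BatchNorm

def freeze_sketch(module, train_bn=True):
    """Sketch only: freeze everything except BatchNorm layers when train_bn=True."""
    for m in flatten_with_parent_params(module):
        if train_bn and isinstance(m, _BatchNorm):
            continue  # BatchNorm stays trainable
        for p in m.parameters(recurse=False):  # direct params only; covers parent_param
            p.requires_grad = False

The 8 → 9 filter_params count follows the same arithmetic: ConvBlockParam now contributes five encoder parameters (conv.weight, conv.bias, bn.weight, bn.bias, parent_param) and the second ConvBlock contributes the usual four.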