diff --git a/CHANGELOG.md b/CHANGELOG.md
index 76cc6a6486851..ef38828a54d1a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -235,6 +235,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))
 
+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
+
+
 ## [1.3.5] - 2021-06-08
 
 ### Added
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index a6c13d1b0c0db..29430d866288d 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -105,7 +105,8 @@ def on_load_checkpoint(
     @staticmethod
     def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         """
-        This function is used to flatten a module or an iterable of modules into a list of its modules.
+        This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
+        with no children) and parent modules that have parameters directly themselves.
 
         Args:
             modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()
 
-        # Leaf nodes in the graph have no children, so we use that to filter
-        return [m for m in _modules if not list(m.children())]
+        # Capture all leaf modules as well as parent modules that have parameters directly themselves
+        return [m for m in _modules if not list(m.children()) or m._parameters]
 
     @staticmethod
     def filter_params(
@@ -136,7 +137,6 @@ def filter_params(
             modules: A given module or an iterable of modules
             train_bn: Whether to train BatchNorm module
             requires_grad: Whether to create a generator for trainable or non-trainable parameters.
-
         Returns:
             Generator
         """
@@ -144,7 +144,8 @@ def filter_params(
         for mod in modules:
             if isinstance(mod, _BatchNorm) and not train_bn:
                 continue
-            for param in mod.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in mod.parameters(recurse=False):
                 if param.requires_grad == requires_grad:
                     yield param
 
@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for module in modules:
-            for param in module.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in module.parameters(recurse=False):
                 param.requires_grad = True
 
     @staticmethod
@@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
             if isinstance(mod, _BatchNorm) and train_bn:
                 BaseFinetuning.make_trainable(mod)
             else:
-                for param in mod.parameters():
+                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+                for param in mod.parameters(recurse=False):
                     param.requires_grad = False
 
     @staticmethod
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index 53d34c4645bef..0d6a0e3f0a3d1 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -307,7 +307,11 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_deep_nested_model():
+def test_complex_nested_model():
+    """
+    Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    directly themselves rather than exclusively their submodules containing parameters.
+    """
 
     class ConvBlock(nn.Module):
 
@@ -322,23 +326,39 @@ def forward(self, x):
             x = self.act(x)
             return self.bn(x)
 
+    class ConvBlockParam(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
     model = nn.Sequential(
         OrderedDict([
-            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
             ("decoder", ConvBlock(128, 10)),
         ])
     )
 
-    # There's 9 leaf layers in that model
-    assert len(BaseFinetuning.flatten_modules(model)) == 9
+    # There are 10 leaf modules or parent modules w/ parameters in the test model
+    assert len(BaseFinetuning.flatten_modules(model)) == 10
 
     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad
+    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad
 
     BaseFinetuning.make_trainable(model)
     encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
-    # The 8 parameters of the encoder are:
-    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
+    # The 9 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
     # conv1.weight, conv1.bias, bn1.weight, bn1.bias
-    assert len(encoder_params) == 8
+    assert len(encoder_params) == 9
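
Note: a minimal usage sketch (not part of the patch) illustrating the behavior this change targets, assuming the patched `BaseFinetuning` is importable from `pytorch_lightning.callbacks.finetuning`; the `ParentWithParam` module and its `scale` parameter are illustrative names only, not part of the codebase.

# Sketch: a non-leaf module that owns a parameter directly, the case fixed by this patch.
import torch
import torch.nn as nn

from pytorch_lightning.callbacks.finetuning import BaseFinetuning


class ParentWithParam(nn.Module):
    """A parent (non-leaf) module that holds a parameter itself in addition to a child module."""

    def __init__(self):
        super().__init__()
        self.child = nn.Linear(4, 4)
        self.scale = nn.Parameter(torch.ones(1))  # parameter owned by the parent directly

    def forward(self, x):
        return self.child(x) * self.scale


model = ParentWithParam()

# With the patched `flatten_modules`, the parent is kept alongside its leaf child
# because it has parameters of its own; previously only leaf modules were returned.
flat = BaseFinetuning.flatten_modules(model)
assert any(m is model for m in flat)

# `freeze` now reaches `scale` as well, and iterating with `parameters(recurse=False)`
# per flattened module keeps each parameter from being visited twice.
BaseFinetuning.freeze(model, train_bn=True)
assert not model.scale.requires_grad
assert not model.child.weight.requires_grad

# Each of the three parameters (scale, child.weight, child.bias) is yielded exactly once.
frozen = list(BaseFinetuning.filter_params(model, train_bn=True, requires_grad=False))
assert len(frozen) == 3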