From 7fdedf9b448c3ae80b98f3cf6cc6c2a367a61c83 Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Thu, 10 Jun 2021 14:08:54 -0700
Subject: [PATCH 1/4] bugfix for #7930 w/ associated new test

---
 pytorch_lightning/callbacks/finetuning.py   | 17 ++++---
 tests/callbacks/test_finetuning_callback.py | 56 +++++++++++++++++++++
 2 files changed, 66 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index a6c13d1b0c0db..29430d866288d 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -105,7 +105,8 @@ def on_load_checkpoint(
     @staticmethod
     def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         """
-        This function is used to flatten a module or an iterable of modules into a list of its modules.
+        This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
+        with no children) and parent modules that have parameters directly themselves.

         Args:
             modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()

-        # Leaf nodes in the graph have no children, so we use that to filter
-        return [m for m in _modules if not list(m.children())]
+        # Capture all leaf modules as well as parent modules that have parameters directly themselves
+        return [m for m in _modules if not list(m.children()) or m._parameters]

     @staticmethod
     def filter_params(
@@ -136,7 +137,6 @@ def filter_params(
             modules: A given module or an iterable of modules
             train_bn: Whether to train BatchNorm module
             requires_grad: Whether to create a generator for trainable or non-trainable parameters.
-
         Returns:
             Generator
         """
@@ -144,7 +144,8 @@ def filter_params(
         for mod in modules:
             if isinstance(mod, _BatchNorm) and not train_bn:
                 continue
-            for param in mod.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in mod.parameters(recurse=False):
                 if param.requires_grad == requires_grad:
                     yield param

@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for module in modules:
-            for param in module.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in module.parameters(recurse=False):
                 param.requires_grad = True

     @staticmethod
@@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
             if isinstance(mod, _BatchNorm) and train_bn:
                 BaseFinetuning.make_trainable(mod)
             else:
-                for param in mod.parameters():
+                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+                for param in mod.parameters(recurse=False):
                     param.requires_grad = False

     @staticmethod
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index 53d34c4645bef..e42b39e2cd0b4 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -342,3 +342,59 @@ def forward(self, x):
     # conv0.weight, conv0.bias, bn0.weight, bn0.bias
     # conv1.weight, conv1.bias, bn1.weight, bn1.bias
     assert len(encoder_params) == 8
+
+
+def test_parent_module_w_param_model():
+    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    directly themselves rather than exclusively their submodules containing parameters.
+    """
+
+    class ConvBlock(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    class ConvBlockParam(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            # add trivial test parameter to conv block to validate parent (non-leaf) module parameter handling
+            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    model = nn.Sequential(
+        OrderedDict([
+            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
+            ("decoder", ConvBlock(128, 10)),
+        ])
+    )
+
+    # There are 10 leaf modules or parent modules w/ parameters in the test model
+    assert len(BaseFinetuning.flatten_modules(model)) == 10
+
+    BaseFinetuning.freeze(model.encoder, train_bn=True)
+    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
+    assert model.encoder[0].bn.weight.requires_grad
+
+    BaseFinetuning.make_trainable(model)
+    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
+    # The 9 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
+    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
+    assert len(encoder_params) == 9

From 8b8a8cf8f824a78ef89d4815c2bacbc9fad5e9c5 Mon Sep 17 00:00:00 2001
From: Daniel Dale
Date: Fri, 11 Jun 2021 11:53:32 -0700
Subject: [PATCH 2/4] replaced test_deep_nested_model with test_complex_nested_model as it was a superset, updated changelog

---
 CHANGELOG.md                                |  1 +
 tests/callbacks/test_finetuning_callback.py | 41 +--------------------
 2 files changed, 3 insertions(+), 39 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9e63484ca8612..d57f6e5a18c85 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -208,6 +208,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
 - Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))

diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index e42b39e2cd0b4..b0b87b3620ac5 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -307,44 +307,7 @@ def configure_optimizers(self):
     trainer.fit(model)


-def test_deep_nested_model():
-
-    class ConvBlock(nn.Module):
-
-        def __init__(self, in_channels, out_channels):
-            super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
-            self.bn = nn.BatchNorm2d(out_channels)
-
-        def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
-            return self.bn(x)
-
-    model = nn.Sequential(
-        OrderedDict([
-            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
-            ("decoder", ConvBlock(128, 10)),
-        ])
-    )
-
-    # There's 9 leaf layers in that model
-    assert len(BaseFinetuning.flatten_modules(model)) == 9
-
-    BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad
-    assert model.encoder[0].bn.weight.requires_grad
-
-    BaseFinetuning.make_trainable(model)
-    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
-    # The 8 parameters of the encoder are:
-    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
-    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
-    assert len(encoder_params) == 8
-
-
-def test_parent_module_w_param_model():
+def test_complex_nested_model():
     """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
     directly themselves rather than exclusively their submodules containing parameters.
     """
@@ -368,7 +331,7 @@ def __init__(self, in_channels, out_channels):
             super().__init__()
             self.conv = nn.Conv2d(in_channels, out_channels, 3)
             self.act = nn.ReLU()
-            # add trivial test parameter to conv block to validate parent (non-leaf) module parameter handling
+            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)

From ed9008063b568bdd905fd5fc5310075e3bb89e39 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Sat, 12 Jun 2021 13:27:52 +0200
Subject: [PATCH 3/4] Apply suggestions from code review

---
 CHANGELOG.md                                | 3 +++
 tests/callbacks/test_finetuning_callback.py | 3 ++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d57f6e5a18c85..4e5a358822b86 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -208,8 +208,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

+
+
 - Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
+
 - Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))

 - Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))

diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index b0b87b3620ac5..0d6a0e3f0a3d1 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -308,7 +308,8 @@ def configure_optimizers(self):


 def test_complex_nested_model():
-    """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    """
+    Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
     directly themselves rather than exclusively their submodules containing parameters.
     """

From 20d5481c77cfc7c9b33cc4fffd521956be0e9b6a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Sun, 13 Jun 2021 17:42:32 +0200
Subject: [PATCH 4/4] Update CHANGELOG.md

---
 CHANGELOG.md | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b52ede134a7fe..3e2c14457d3b9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -218,18 +218,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `DataModule.prepare_data` could only be called on the global rank 0 process ([#7945](https://github.com/PyTorchLightning/pytorch-lightning/pull/7945))


-- Fixed `_check_training_step_output` to be called after `train_step_end` to support more flexible accomodations ([#7868](https://github.com/PyTorchLightning/pytorch-lightning/pull/7868))
-
-
-- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
-
-
 - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))


-- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
-
-
 - Fixed dev debugger memory growing due to tracking events even when disabled ([#7875](https://github.com/PyTorchLightning/pytorch-lightning/pull/7875))
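
The snippet below is a minimal, standalone sketch of the behaviour this series addresses, intended only as a reading aid for the diff above. It is illustration, not code from the patches: the `Block` module and its `child` / `parent_param` attributes are hypothetical names invented for the example, and only `torch` / `torch.nn` are assumed to be available.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """A parent (non-leaf) module that owns a parameter directly, alongside a child submodule."""

    def __init__(self):
        super().__init__()
        self.child = nn.Linear(4, 4)
        self.parent_param = nn.Parameter(torch.zeros(1))


block = Block()

# Old behaviour: flatten_modules kept only leaf modules, so `block` (and with it
# `parent_param`) never appeared in the flattened list and was never frozen/thawed.
leaves_only = [m for m in block.modules() if not list(m.children())]
assert block not in leaves_only

# Fixed behaviour: also keep parent modules that hold parameters directly
# (the same predicate the patch uses in flatten_modules).
flattened = [m for m in block.modules() if not list(m.children()) or m._parameters]
assert block in flattened

# With the parent now in the list, recursive parameters() would double-count the
# child's weights (once via the parent, once via the child itself) ...
duplicated = [p for m in flattened for p in m.parameters()]
assert len(duplicated) == 5  # child.weight and child.bias counted twice, plus parent_param

# ... which is why the patch switches to parameters(recurse=False).
deduplicated = [p for m in flattened for p in m.parameters(recurse=False)]
assert len(deduplicated) == 3  # child.weight, child.bias, parent_param
```

Read this way, the two halves of the fix belong together: the widened predicate in `flatten_modules` makes directly-owned parameters visible to the callback, and switching to `parameters(recurse=False)` in `filter_params`, `make_trainable`, and `freeze` keeps the now-included parent modules from yielding their children's parameters a second time.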