From cd959b4b2e1e2519333fe77ec7d1638373118eca Mon Sep 17 00:00:00 2001
From: scart97
Date: Wed, 7 Apr 2021 19:39:02 -0300
Subject: [PATCH 1/4] Add test that fails

---
 tests/callbacks/test_finetuning_callback.py | 38 +++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index 6aa32934ad771..ff70d901120f8 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
 import pytest
 import torch
 from torch import nn
@@ -244,3 +245,40 @@ def configure_optimizers(self):
 
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
     trainer.fit(model)
+
+
+def test_deep_nested_model():
+    class ConvBlock(nn.Module):
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    model = nn.Sequential(
+        OrderedDict(
+            [
+                ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+                ("decoder", ConvBlock(128, 10)),
+            ]
+        )
+    )
+
+    # There's 9 leaf layers in that model
+    assert len(BaseFinetuning.flatten_modules(model)) == 9
+
+    BaseFinetuning.freeze(model.encoder, train_bn=True)
+    assert not model.encoder[0].conv.weight.requires_grad
+    assert model.encoder[0].bn.weight.requires_grad
+
+    BaseFinetuning.make_trainable(model)
+    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
+    # The 8 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
+    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
+    assert len(encoder_params) == 8

From 39efb117efa0dfcfcbedc1041b0f0adb6a045244 Mon Sep 17 00:00:00 2001
From: scart97
Date: Wed, 7 Apr 2021 19:40:12 -0300
Subject: [PATCH 2/4] flatten_modules now only lists leaf nodes

---
 pytorch_lightning/callbacks/finetuning.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index f37e3bb31cc5e..d0ebd1fea93cf 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -22,7 +22,6 @@
 import torch
 from torch.nn import Module
 from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.callbacks.base import Callback
@@ -102,11 +101,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         else:
             _modules = modules.modules()
 
-        return list(
-            filter(
-                lambda m: not isinstance(m, (Container, Sequential, ModuleDict, ModuleList, LightningModule)), _modules
-            )
-        )
+        # Leaf nodes in the graph have no children, so we use that to filter
+        return list(filter(lambda m: list(m.children()) == [], _modules))
 
     @staticmethod
     def filter_params(

From e9cb8d7d98b13e69f9253a3c90c0c2273e9431c4 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 8 Apr 2021 02:05:42 +0200
Subject: [PATCH 3/4] Minor changes.
 Apply pre-commit

---
 pytorch_lightning/callbacks/finetuning.py   |  2 +-
 tests/callbacks/test_finetuning_callback.py | 13 +++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index d0ebd1fea93cf..ea508775d126f 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -102,7 +102,7 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
             _modules = modules.modules()
 
         # Leaf nodes in the graph have no children, so we use that to filter
-        return list(filter(lambda m: list(m.children()) == [], _modules))
+        return [m for m in _modules if not list(m.children())]
 
     @staticmethod
     def filter_params(
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index ff70d901120f8..c11d58cb18543 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections import OrderedDict
+
 import pytest
 import torch
 from torch import nn
@@ -248,7 +249,9 @@ def configure_optimizers(self):
 
 
 def test_deep_nested_model():
+
     class ConvBlock(nn.Module):
+
         def __init__(self, in_channels, out_channels):
             super().__init__()
             self.conv = nn.Conv2d(in_channels, out_channels, 3)
@@ -261,12 +264,10 @@ def forward(self, x):
         return self.bn(x)
 
     model = nn.Sequential(
-        OrderedDict(
-            [
-                ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
-                ("decoder", ConvBlock(128, 10)),
-            ]
-        )
+        OrderedDict([
+            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+            ("decoder", ConvBlock(128, 10)),
+        ])
     )
 
     # There's 9 leaf layers in that model

From b2e45cb370899a0e56e3f5122c300b78cb5c6040 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Thu, 8 Apr 2021 02:07:47 +0200
Subject: [PATCH 4/4] Update CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d9cab7a63eca6..f0c6ef65916d8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -222,6 +222,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
 
 
+- Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879))
+
+
 - Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
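
Editor note (not part of the patches above): a minimal standalone sketch of the behaviour these patches change. It only needs torch and mirrors the ConvBlock model from the test; the "old behaviour" filter is approximated without the Container/LightningModule checks so the snippet is self-contained, and the exact parameter-collection path inside BaseFinetuning is not reproduced here.

from collections import OrderedDict

import torch.nn as nn


class ConvBlock(nn.Module):
    # A custom composite module: not one of the known container types, but not a leaf either.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3)
        self.act = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.act(self.conv(x)))


model = nn.Sequential(
    OrderedDict([
        ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
        ("decoder", ConvBlock(128, 10)),
    ])
)

# Old behaviour (approximated): only known container types are dropped, so each
# custom ConvBlock wrapper stays in the flattened list next to its own conv/act/bn children.
old_flat = [m for m in model.modules() if not isinstance(m, (nn.Sequential, nn.ModuleDict, nn.ModuleList))]

# New behaviour: keep only leaf modules, i.e. modules without children.
new_flat = [m for m in model.modules() if not list(m.children())]

print(len(old_flat))  # 12 -> the 3 ConvBlock wrappers plus their 9 leaf layers
print(len(new_flat))  # 9  -> conv/act/bn for each of the 3 blocks

# Collecting parameters recursively over the old list counts every ConvBlock's weights
# twice (once via the wrapper, once via its conv/bn children) -- the duplication the
# CHANGELOG entry refers to.
print(sum(p.numel() for m in old_flat for p in m.parameters()))  # exactly double
print(sum(p.numel() for m in new_flat for p in m.parameters()))  # true parameter count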