diff --git a/CHANGELOG.md b/CHANGELOG.md
index b2a1229d30ede..7b5b93cc9e00c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -207,6 +207,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
 
+- Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879))
+
+
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
 
 
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index f37e3bb31cc5e..ea508775d126f 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -22,7 +22,6 @@
 import torch
 from torch.nn import Module
 from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.callbacks.base import Callback
@@ -102,11 +101,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()
 
-        return list(
-            filter(
-                lambda m: not isinstance(m, (Container, Sequential, ModuleDict, ModuleList, LightningModule)), _modules
-            )
-        )
+        # Leaf nodes in the graph have no children, so we use that to filter
+        return [m for m in _modules if not list(m.children())]
 
     @staticmethod
     def filter_params(
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index 6aa32934ad771..c11d58cb18543 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
+
 import pytest
 import torch
 from torch import nn
@@ -244,3 +246,40 @@ def configure_optimizers(self):
 
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
     trainer.fit(model)
+
+
+def test_deep_nested_model():
+
+    class ConvBlock(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    model = nn.Sequential(
+        OrderedDict([
+            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+            ("decoder", ConvBlock(128, 10)),
+        ])
+    )
+
+    # There's 9 leaf layers in that model
+    assert len(BaseFinetuning.flatten_modules(model)) == 9
+
+    BaseFinetuning.freeze(model.encoder, train_bn=True)
+    assert not model.encoder[0].conv.weight.requires_grad
+    assert model.encoder[0].bn.weight.requires_grad
+
+    BaseFinetuning.make_trainable(model)
+    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
+    # The 8 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
+    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
+    assert len(encoder_params) == 8
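
For context, here is a minimal, self-contained sketch (not part of the patch) of the duplication this change fixes, assuming only `torch` is installed: a custom block such as `ConvBlock` is not one of the excluded container types, so the old `isinstance`-based filter kept both the block and its leaf children, and iterating `parameters()` over the flattened list visited the same tensors twice. The `isinstance(m, nn.Sequential)` check below is a simplified stand-in for the removed filter, not the library's exact code.

```python
from torch import nn


class ConvBlock(nn.Module):
    """A custom block: not a container type, so the old filter kept it."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3)
        self.bn = nn.BatchNorm2d(out_channels)


encoder = nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))

# Old behaviour (simplified): drop known container types only. Each ConvBlock
# survives alongside its own conv/bn children, so its parameters are seen twice.
old_flat = [m for m in encoder.modules() if not isinstance(m, nn.Sequential)]
old_count = sum(1 for m in old_flat for _ in m.parameters())

# New behaviour: keep only leaf modules (those with no children), so each
# parameter is yielded exactly once.
new_flat = [m for m in encoder.modules() if not list(m.children())]
new_count = sum(1 for m in new_flat for _ in m.parameters())

assert new_count == 8               # conv/bn weight + bias for two blocks
assert old_count == 2 * new_count   # the old filter double-counted them
```

Filtering on "has no children" also keeps the `_BatchNorm` special-casing intact, since batch-norm layers are themselves leaves.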