
Commit 9b6374f

Fix finetuning so that complex models are correctly unfrozen (#6880)

scart97 authored and SeanNaren committed

Co-authored-by: Carlos Mocholi <[email protected]>
(cherry picked from commit eb15abc)
1 parent 593ae70 commit 9b6374f

File tree

3 files changed: +51 -11 lines

CHANGELOG.md
pytorch_lightning/callbacks/finetuning.py
tests/callbacks/test_finetuning_callback.py

CHANGELOG.md

Lines changed: 10 additions & 5 deletions

@@ -13,9 +13,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
 
 
+- Added `LightningCLI` class to provide simple reproducibility with minimum boilerplate training cli. ([#4492](https://github.com/PyTorchLightning/pytorch-lightning/pull/4492))
+
+
 - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))
 
 
+- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](https://github.com/PyTorchLightning/pytorch-lightning/pull/6123)).
+
+
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
 
@@ -75,6 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
 
+- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/issues/6736))
 
 - Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))
 
@@ -208,13 +215,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
 
-- Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
+- Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879))
 
 
-- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
-
-
-- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
+- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
 
 
 ## [1.2.8] - 2021-04-13
@@ -343,6 +347,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed incorrect yield logic for the amp autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
 - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
 - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
+- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
 
 
 ## [1.2.0] - 2021-02-18
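Two of the changelog entries added above name new Trainer arguments. As a quick orientation, here is a minimal sketch of how they are typically passed; this is illustrative usage based on the public Trainer API, not code from this commit, and defaults or accepted values may differ between releases:

from pytorch_lightning import Trainer

# Clip gradients by value instead of by norm (changelog entry #6123),
# and train in double precision (changelog entry #6595).
trainer = Trainer(
    gradient_clip_val=0.5,
    gradient_clip_algorithm="value",
    precision=64,
)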

pytorch_lightning/callbacks/finetuning.py

Lines changed: 2 additions & 6 deletions

@@ -22,7 +22,6 @@
 import torch
 from torch.nn import Module
 from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.container import Container, ModuleDict, ModuleList, Sequential
 from torch.optim.optimizer import Optimizer
 
 from pytorch_lightning.callbacks.base import Callback
@@ -102,11 +101,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()
 
-        return list(
-            filter(
-                lambda m: not isinstance(m, (Container, Sequential, ModuleDict, ModuleList, LightningModule)), _modules
-            )
-        )
+        # Leaf nodes in the graph have no children, so we use that to filter
+        return [m for m in _modules if not list(m.children())]
 
     @staticmethod
     def filter_params(
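Why the old implementation duplicated parameters: it only filtered out the known container types, so a custom module that wraps children (such as the `ConvBlock` added in the test below) stayed in the flattened list next to its own children, and its parameters were gathered once through the wrapper and once more through each child. The leaf-only check keeps each parameter's owning module exactly once. A standalone sketch of the difference (illustrative only; `old_flatten`, `new_flatten` and `count_params` are hypothetical stand-ins, not library code):

from torch import nn


class ConvBlock(nn.Module):
    # A custom wrapper with children: it is not Sequential, ModuleList,
    # ModuleDict or a LightningModule, so the old type-based filter kept it.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)


def old_flatten(root):
    # Old behaviour (simplified): drop only known container types.
    containers = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
    return [m for m in root.modules() if not isinstance(m, containers)]


def new_flatten(root):
    # New behaviour: keep only leaf modules, i.e. modules without children.
    return [m for m in root.modules() if not list(m.children())]


def count_params(flattened):
    # Gathering parameters module-by-module over the flattened list double
    # counts whenever a non-leaf module such as ConvBlock is still present.
    return len([p for m in flattened for p in m.parameters()])


model = nn.Sequential(ConvBlock(), ConvBlock())
print(count_params(old_flatten(model)))  # 16: every weight/bias counted twice
print(count_params(new_flatten(model)))  # 8: conv and bn weights/biases, once each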

tests/callbacks/test_finetuning_callback.py

Lines changed: 39 additions & 0 deletions

@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections import OrderedDict
+
 import pytest
 import torch
 from torch import nn
@@ -244,3 +246,40 @@ def configure_optimizers(self):
 
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
     trainer.fit(model)
+
+
+def test_deep_nested_model():
+
+    class ConvBlock(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
+    model = nn.Sequential(
+        OrderedDict([
+            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+            ("decoder", ConvBlock(128, 10)),
+        ])
+    )
+
+    # There's 9 leaf layers in that model
+    assert len(BaseFinetuning.flatten_modules(model)) == 9
+
+    BaseFinetuning.freeze(model.encoder, train_bn=True)
+    assert not model.encoder[0].conv.weight.requires_grad
+    assert model.encoder[0].bn.weight.requires_grad
+
+    BaseFinetuning.make_trainable(model)
+    encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
+    # The 8 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
+    # conv1.weight, conv1.bias, bn1.weight, bn1.bias
+    assert len(encoder_params) == 8
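As a companion to the new test, the user-visible effect of the fix can be seen by checking that `filter_params` now yields each parameter of a nested model exactly once, so its output can safely be turned into an optimizer param group. A small hedged sketch reusing the test's model layout (the uniqueness check is illustrative and not part of the test suite):

from collections import OrderedDict

from torch import nn

from pytorch_lightning.callbacks import BaseFinetuning


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3)
        self.act = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)


model = nn.Sequential(OrderedDict([
    ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
    ("decoder", ConvBlock(128, 10)),
]))

params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
# With the leaf-node filter, no parameter object is yielded more than once.
assert len(params) == len({id(p) for p in params}) == 8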
