
Commit c8d9ad8

speediedan committed with Daniel Dale, carmocca, kaushikb11, and awaelchli
Properly handle parent modules w/ parameters in BaseFinetuning callback (#7931)
Co-authored-by: Daniel Dale <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 0d48daf commit c8d9ad8
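
For context, the "parent modules w/ parameters" in the title are non-leaf modules that register an `nn.Parameter` directly on themselves rather than only inside child modules. Before this change, `BaseFinetuning.flatten_modules` kept only leaf modules, so such directly-owned parameters were never visited when freezing or filtering. A minimal sketch in plain PyTorch (the `Wrapper` module and its attribute names are illustrative, not taken from the commit):

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    """Illustrative parent (non-leaf) module that owns a parameter directly."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)             # child (leaf) module
        self.scale = nn.Parameter(torch.ones(1))  # parameter registered on the parent itself


wrapper = Wrapper()

# Flattening to leaf modules only (modules without children) drops `wrapper` itself,
# so `wrapper.scale` is invisible to any code that iterates the flattened list.
leaf_only = [m for m in wrapper.modules() if not list(m.children())]
assert leaf_only == [wrapper.linear]
assert all(p is not wrapper.scale for m in leaf_only for p in m.parameters())
```

The diffs below address this in two places: the flattening filter in `flatten_modules`, and the parameter iteration in `filter_params`, `make_trainable`, and `freeze`.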

File tree (3 files changed, +41 / -15 lines):

CHANGELOG.md
pytorch_lightning/callbacks/finetuning.py
tests/callbacks/test_finetuning_callback.py


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
 
 
+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))
+
+
 ## [1.3.5] - 2021-06-08
 
 ### Added

pytorch_lightning/callbacks/finetuning.py

Lines changed: 10 additions & 7 deletions

@@ -105,7 +105,8 @@ def on_load_checkpoint(
     @staticmethod
     def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         """
-        This function is used to flatten a module or an iterable of modules into a list of its modules.
+        This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
+        with no children) and parent modules that have parameters directly themselves.
 
         Args:
             modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
         else:
             _modules = modules.modules()
 
-        # Leaf nodes in the graph have no children, so we use that to filter
-        return [m for m in _modules if not list(m.children())]
+        # Capture all leaf modules as well as parent modules that have parameters directly themselves
+        return [m for m in _modules if not list(m.children()) or m._parameters]
 
     @staticmethod
     def filter_params(
@@ -136,15 +137,15 @@ def filter_params(
             modules: A given module or an iterable of modules
             train_bn: Whether to train BatchNorm module
             requires_grad: Whether to create a generator for trainable or non-trainable parameters.
-
         Returns:
             Generator
         """
         modules = BaseFinetuning.flatten_modules(modules)
         for mod in modules:
             if isinstance(mod, _BatchNorm) and not train_bn:
                 continue
-            for param in mod.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in mod.parameters(recurse=False):
                 if param.requires_grad == requires_grad:
                     yield param
 
@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
         """
         modules = BaseFinetuning.flatten_modules(modules)
        for module in modules:
-            for param in module.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in module.parameters(recurse=False):
                 param.requires_grad = True
 
     @staticmethod
@@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
         if isinstance(mod, _BatchNorm) and train_bn:
             BaseFinetuning.make_trainable(mod)
         else:
-            for param in mod.parameters():
+            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+            for param in mod.parameters(recurse=False):
                 param.requires_grad = False
 
     @staticmethod
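
The `recurse=False` change goes hand in hand with the new flattening filter: once parameter-bearing parent modules are kept in the flattened list, recursive iteration would visit their children's parameters a second time, which is what the repeated "recursion could yield duplicate parameters" comment refers to. A minimal sketch of that double counting in plain PyTorch (the toy `parent`/`child` names are illustrative):

```python
import torch
from torch import nn

# A toy parent module that owns a parameter directly, plus one child layer.
parent = nn.Module()
parent.child = nn.Linear(2, 2)
parent.parent_param = nn.Parameter(torch.zeros(1))

# The updated filter keeps the parameter-bearing parent as well as the leaf child.
flattened = [m for m in parent.modules() if not list(m.children()) or m._parameters]
assert flattened == [parent, parent.child]

# Recursive iteration over that list sees child.weight and child.bias twice:
# once through `parent` and once through `parent.child` itself.
recursive = [p for m in flattened for p in m.parameters()]
assert len(recursive) == 5

# With recurse=False each module yields only the parameters it owns directly,
# so every parameter appears exactly once (parent_param, child.weight, child.bias).
direct = [p for m in flattened for p in m.parameters(recurse=False)]
assert len(direct) == 3
```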

tests/callbacks/test_finetuning_callback.py

Lines changed: 28 additions & 8 deletions

@@ -307,7 +307,11 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-def test_deep_nested_model():
+def test_complex_nested_model():
+    """
+    Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters
+    directly themselves rather than exclusively their submodules containing parameters.
+    """
 
     class ConvBlock(nn.Module):
 
@@ -322,23 +326,39 @@ def forward(self, x):
         x = self.act(x)
         return self.bn(x)
 
+    class ConvBlockParam(nn.Module):
+
+        def __init__(self, in_channels, out_channels):
+            super().__init__()
+            self.conv = nn.Conv2d(in_channels, out_channels, 3)
+            self.act = nn.ReLU()
+            # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+            self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+            self.bn = nn.BatchNorm2d(out_channels)
+
+        def forward(self, x):
+            x = self.conv(x)
+            x = self.act(x)
+            return self.bn(x)
+
     model = nn.Sequential(
         OrderedDict([
-            ("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
+            ("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
             ("decoder", ConvBlock(128, 10)),
         ])
     )
 
-    # There's 9 leaf layers in that model
-    assert len(BaseFinetuning.flatten_modules(model)) == 9
+    # There are 10 leaf modules or parent modules w/ parameters in the test model
+    assert len(BaseFinetuning.flatten_modules(model)) == 10
 
     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad
+    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad
 
     BaseFinetuning.make_trainable(model)
     encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
-    # The 8 parameters of the encoder are:
-    # conv0.weight, conv0.bias, bn0.weight, bn0.bias
+    # The 9 parameters of the encoder are:
+    # conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
     # conv1.weight, conv1.bias, bn1.weight, bn1.bias
-    assert len(encoder_params) == 8
+    assert len(encoder_params) == 9
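
Putting the two changes together, here is a short usage sketch of the affected static methods on a model with a directly-owned parameter. It assumes a PyTorch Lightning release that includes this fix, and the `ParentWithParam` module is illustrative rather than taken from the test suite:

```python
import torch
from torch import nn
from pytorch_lightning.callbacks import BaseFinetuning  # assumes a release containing this fix


class ParentWithParam(nn.Module):
    """Illustrative module owning a parameter directly in addition to a child layer."""

    def __init__(self):
        super().__init__()
        self.child = nn.Linear(2, 2)
        self.parent_param = nn.Parameter(torch.zeros(1))


net = ParentWithParam()

# The flattened list now contains the parameter-bearing parent as well as its leaf child.
assert BaseFinetuning.flatten_modules(net) == [net, net.child]

# freeze() therefore reaches the directly-owned parameter too ...
BaseFinetuning.freeze(net, train_bn=True)
assert not net.parent_param.requires_grad

# ... and filter_params() yields each of the three frozen parameters exactly once.
frozen = list(BaseFinetuning.filter_params(net, train_bn=True, requires_grad=False))
assert len(frozen) == 3
```

Each parameter is reached exactly once because iteration over the flattened list uses `recurse=False`.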
