3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -235,6 +235,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))


+- Fixed `BaseFinetuning` callback to properly handle parent modules w/ parameters ([#7931](https://github.com/PyTorchLightning/pytorch-lightning/pull/7931))


## [1.3.5] - 2021-06-08

### Added
17 changes: 10 additions & 7 deletions pytorch_lightning/callbacks/finetuning.py
@@ -105,7 +105,8 @@ def on_load_checkpoint(
@staticmethod
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
"""
-This function is used to flatten a module or an iterable of modules into a list of its modules.
+This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules
+with no children) and parent modules that have parameters directly themselves.

Args:
modules: A given module or an iterable of modules
@@ -121,8 +122,8 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
else:
_modules = modules.modules()

-# Leaf nodes in the graph have no children, so we use that to filter
-return [m for m in _modules if not list(m.children())]
+# Capture all leaf modules, as well as parent modules that have parameters directly themselves
+return [m for m in _modules if not list(m.children()) or m._parameters]

@staticmethod
def filter_params(
@@ -136,15 +137,15 @@ def filter_params(
modules: A given module or an iterable of modules
train_bn: Whether to train BatchNorm module
requires_grad: Whether to create a generator for trainable or non-trainable parameters.

Returns:
Generator
"""
modules = BaseFinetuning.flatten_modules(modules)
for mod in modules:
if isinstance(mod, _BatchNorm) and not train_bn:
continue
-for param in mod.parameters():
+# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+for param in mod.parameters(recurse=False):
if param.requires_grad == requires_grad:
yield param

@@ -158,7 +159,8 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) ->
"""
modules = BaseFinetuning.flatten_modules(modules)
for module in modules:
-for param in module.parameters():
+# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+for param in module.parameters(recurse=False):
param.requires_grad = True

@staticmethod
Expand All @@ -178,7 +180,8 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn:
if isinstance(mod, _BatchNorm) and train_bn:
BaseFinetuning.make_trainable(mod)
else:
-for param in mod.parameters():
+# recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
+for param in mod.parameters(recurse=False):
param.requires_grad = False

@staticmethod
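For context, here is a minimal standalone sketch (not part of the PR) of the behaviour the `finetuning.py` changes address; the `Block` module and its `scale` parameter are illustrative assumptions, not code from the repository:

```python
import torch
from torch import nn


class Block(nn.Module):
    """Parent module that registers a parameter directly on itself."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)
        self.scale = nn.Parameter(torch.ones(1))  # held by the parent, not by a child


block = Block()

# Keeping only leaf modules (the previous behaviour) drops `scale` entirely,
# because Block has a child and is therefore filtered out.
leaves = [m for m in block.modules() if not list(m.children())]
assert all(name != "scale" for m in leaves for name, _ in m.named_parameters())

# Keeping leaves *or* modules with direct parameters (the new check) recovers it.
flat = [m for m in block.modules() if not list(m.children()) or m._parameters]
assert any("scale" in dict(m.named_parameters(recurse=False)) for m in flat)

# With the default recurse=True, iterating parameters over the flattened list
# would yield lin.weight and lin.bias twice (once via Block, once via the leaf),
# which is why the callback now iterates with parameters(recurse=False).
assert len([p for m in flat for p in m.parameters()]) == 5
assert len([p for m in flat for p in m.parameters(recurse=False)]) == 3
```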
36 changes: 28 additions & 8 deletions tests/callbacks/test_finetuning_callback.py
@@ -307,7 +307,11 @@ def configure_optimizers(self):
trainer.fit(model)


-def test_deep_nested_model():
+def test_complex_nested_model():
+"""
+Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules holding parameters
+directly themselves, rather than exclusively in their submodules.
+"""

class ConvBlock(nn.Module):

@@ -322,23 +326,39 @@ def forward(self, x):
x = self.act(x)
return self.bn(x)

+class ConvBlockParam(nn.Module):
+
+def __init__(self, in_channels, out_channels):
+super().__init__()
+self.conv = nn.Conv2d(in_channels, out_channels, 3)
+self.act = nn.ReLU()
+# add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
+self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
+self.bn = nn.BatchNorm2d(out_channels)
+
+def forward(self, x):
+x = self.conv(x)
+x = self.act(x)
+return self.bn(x)

model = nn.Sequential(
OrderedDict([
("encoder", nn.Sequential(ConvBlock(3, 64), ConvBlock(64, 128))),
("encoder", nn.Sequential(ConvBlockParam(3, 64), ConvBlock(64, 128))),
("decoder", ConvBlock(128, 10)),
])
)

-# There's 9 leaf layers in that model
-assert len(BaseFinetuning.flatten_modules(model)) == 9
+# There are 10 leaf modules or parent modules w/ parameters in the test model
+assert len(BaseFinetuning.flatten_modules(model)) == 10

BaseFinetuning.freeze(model.encoder, train_bn=True)
-assert not model.encoder[0].conv.weight.requires_grad
+assert not model.encoder[0].conv.weight.requires_grad # Validate a leaf module parameter is frozen
+assert not model.encoder[0].parent_param.requires_grad # Validate the parent module parameter is frozen
assert model.encoder[0].bn.weight.requires_grad

BaseFinetuning.make_trainable(model)
encoder_params = list(BaseFinetuning.filter_params(model.encoder, train_bn=True))
-# The 8 parameters of the encoder are:
-# conv0.weight, conv0.bias, bn0.weight, bn0.bias
+# The 9 parameters of the encoder are:
+# conv0.weight, conv0.bias, bn0.weight, bn0.bias, parent_param
# conv1.weight, conv1.bias, bn1.weight, bn1.bias
-assert len(encoder_params) == 8
+assert len(encoder_params) == 9
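As a usage note, the fix matters for any `BaseFinetuning` subclass that freezes such a model. Below is a hedged sketch following the feature-extractor pattern from the Lightning docs of this release line; the class name, `unfreeze_at_epoch`, and the `encoder` attribute on the LightningModule are assumptions for illustration, not part of the PR:

```python
from pytorch_lightning.callbacks import BaseFinetuning


class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
    """Freeze the encoder (including parent-held parameters) and unfreeze it later."""

    def __init__(self, unfreeze_at_epoch: int = 10):
        super().__init__()
        self._unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # With this PR, freeze() also flips requires_grad on parameters registered
        # directly on non-leaf modules (e.g. `parent_param` in the test above).
        self.freeze(pl_module.encoder, train_bn=True)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        if current_epoch == self._unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.encoder, optimizer=optimizer, train_bn=True
            )
```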