Merged
2 changes: 2 additions & 0 deletions docs/source/callbacks.rst
@@ -94,6 +94,8 @@ Lightning has a few built-in callbacks.
:nosignatures:
:template: classtemplate.rst

BackboneLambdaFinetuningCallback
BaseFinetuningCallback
Callback
EarlyStopping
GPUStatsMonitor
5 changes: 5 additions & 0 deletions pytorch_lightning/callbacks/finetuning.py
@@ -78,6 +78,7 @@ def _recursive_freeze(module: Module,
def filter_params(module: Module,
train_bn: bool = True) -> Generator:
"""Yields the trainable parameters of a given module.

Args:
module: A given module
train_bn: If True, leave the BatchNorm layers in training mode
@@ -98,6 +99,7 @@ def filter_params(module: Module,
@staticmethod
def freeze(module: Module, train_bn: bool = True) -> None:
"""Freezes the layers of a given module.

Args:
module: The module to freeze (at least partially)
train_bn: If True, leave the BatchNorm layers in training mode
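The ``freeze``/``filter_params`` docstrings above describe freezing a module's parameters while optionally leaving BatchNorm layers trainable. A minimal sketch of that idea, assuming PyTorch is available (``freeze_except_bn`` is a hypothetical helper for illustration, not the library's implementation):

```python
import torch.nn as nn

def freeze_except_bn(module: nn.Module, train_bn: bool = True) -> None:
    """Disable gradients for all parameters, optionally sparing BatchNorm."""
    for mod in module.modules():
        if train_bn and isinstance(mod, nn.modules.batchnorm._BatchNorm):
            continue  # leave BatchNorm layers trainable
        # recurse=False: each submodule handles only its own parameters
        for param in mod.parameters(recurse=False):
            param.requires_grad = False

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
freeze_except_bn(net)
# the Linear layer is now frozen; the BatchNorm layer still receives gradients
```

Sparing BatchNorm lets the running statistics keep adapting to the new data distribution even while the backbone weights stay fixed.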
@@ -148,6 +150,7 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
Finetune a backbone model based on a user-defined learning rate schedule.
When the backbone learning rate reaches the current model learning rate
and ``should_align`` is set to True, it will align with it for the rest of the training.

Args:
unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
lambda_func: Scheduling function for increasing backbone learning rate.
@@ -165,7 +168,9 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
reaches it.
verbose: Display current learning rate for model and backbone
round: Precision for displaying learning rate

Example::

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback
>>> multiplicative = lambda epoch: 1.5
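The schedule described in the docstring (the backbone learning rate grows via ``lambda_func`` after unfreezing, then aligns with the model learning rate once ``should_align`` is True) can be simulated in plain Python. This is an illustrative sketch of the behavior, not the callback's actual code; ``backbone_lr_schedule`` and its parameters are hypothetical:

```python
def backbone_lr_schedule(model_lr, initial_ratio, lambda_func,
                         unfreeze_at, epochs, should_align=True):
    """Simulate the per-epoch backbone learning rate (hypothetical helper)."""
    lrs = []
    backbone_lr = 0.0  # backbone frozen: effectively no learning rate
    for epoch in range(epochs):
        if epoch == unfreeze_at:
            # backbone is unfrozen at a fraction of the model learning rate
            backbone_lr = model_lr * initial_ratio
        elif epoch > unfreeze_at:
            backbone_lr *= lambda_func(epoch)
            if should_align and backbone_lr >= model_lr:
                backbone_lr = model_lr  # align for the rest of training
        lrs.append(backbone_lr)
    return lrs

schedule = backbone_lr_schedule(1e-2, 0.1, lambda epoch: 1.5,
                                unfreeze_at=2, epochs=12)
```

With ``multiplicative = lambda epoch: 1.5`` as in the example above, the backbone rate grows by 50% per epoch until it catches up with the model rate, then stays aligned with it.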