
Commit 30f31d3

akihironitta authored, with co-authors carmocca, s-rog, and mergify[bot]
docs: Add BackboneLambdaFinetuningCallback (#5553)
* Add and fix the docs of BackboneLambdaFinetuningCallback

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent db78422 · commit 30f31d3

File tree

2 files changed (+7 lines, -0 lines)

docs/source/callbacks.rst

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,8 @@ Lightning has a few built-in callbacks.
     :nosignatures:
     :template: classtemplate.rst

+    BackboneLambdaFinetuningCallback
+    BaseFinetuningCallback
     Callback
     EarlyStopping
     GPUStatsMonitor

pytorch_lightning/callbacks/finetuning.py

Lines changed: 5 additions & 0 deletions
@@ -78,6 +78,7 @@ def _recursive_freeze(module: Module,
     def filter_params(module: Module,
                       train_bn: bool = True) -> Generator:
         """Yields the trainable parameters of a given module.
+
         Args:
             module: A given module
             train_bn: If True, leave the BatchNorm layers in training mode
@@ -98,6 +99,7 @@ def filter_params(module: Module,
     @staticmethod
     def freeze(module: Module, train_bn: bool = True) -> None:
         """Freezes the layers up to index n (if n is not None).
+
         Args:
             module: The module to freeze (at least partially)
             train_bn: If True, leave the BatchNorm layers in training mode
@@ -148,6 +150,7 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
     Finetunne a backbone model based on a learning rate user-defined scheduling.
     When the backbone learning rate reaches the current model learning rate
     and ``should_align`` is set to True, it will align with it for the rest of the training.
+
     Args:
         unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfreezed.
         lambda_func: Scheduling function for increasing backbone learning rate.
@@ -165,7 +168,9 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
             reaches it.
         verbose: Display current learning rate for model and backbone
         round: Precision for displaying learning rate
+
     Example::
+
         >>> from pytorch_lightning import Trainer
         >>> from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback
         >>> multiplicative = lambda epoch: 1.5
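
The ``Example::`` block above ends at the hunk boundary, so the docstring snippet is shown only partially. A minimal usage sketch consistent with the documented ``Args`` might look like the following; the keyword names and the ``Trainer(callbacks=...)`` wiring are inferred from the diff rather than quoted from this commit:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback

    # Grow the backbone learning rate by a factor of 1.5 each epoch
    # once the backbone has been unfrozen.
    multiplicative = lambda epoch: 1.5

    # The keyword names follow the Args section in the diff above; the full
    # constructor signature is an assumption, since it is not shown here.
    finetuning_callback = BackboneLambdaFinetuningCallback(
        unfreeze_backbone_at_epoch=10,
        lambda_func=multiplicative,
    )

    trainer = Trainer(callbacks=[finetuning_callback])

As the class name suggests, the callback presumably operates on a backbone submodule of the LightningModule; which attribute it expects is not part of this diff.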
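The first two hunks touch docstrings for helpers that freeze a module while optionally keeping its BatchNorm layers in training mode. As a self-contained sketch of that documented behaviour (an illustrative re-implementation with hypothetical names, not the code from this file):

    from typing import Generator

    from torch.nn import Module
    from torch.nn.modules.batchnorm import _BatchNorm


    def filter_params_sketch(module: Module, train_bn: bool = True) -> Generator:
        """Yield the trainable parameters of ``module``, skipping BatchNorm
        layers when ``train_bn`` is True so they can stay in training mode."""
        for child in module.modules():
            if train_bn and isinstance(child, _BatchNorm):
                continue  # leave BatchNorm parameters out of the frozen set
            for param in child.parameters(recurse=False):
                if param.requires_grad:
                    yield param


    def freeze_sketch(module: Module, train_bn: bool = True) -> None:
        """Freeze ``module`` by disabling gradients on the parameters yielded
        above, leaving BatchNorm layers trainable when ``train_bn`` is True."""
        for param in filter_params_sketch(module, train_bn=train_bn):
            param.requires_grad = False

Calling ``freeze_sketch(model.backbone)`` (where ``model.backbone`` is a placeholder) would then disable gradients everywhere in the backbone except its BatchNorm layers, matching the ``train_bn`` semantics spelled out in both docstrings.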
