diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst
index e955ad89fa998..2588cb0a65de1 100644
--- a/docs/source/callbacks.rst
+++ b/docs/source/callbacks.rst
@@ -94,6 +94,8 @@ Lightning has a few built-in callbacks.
     :nosignatures:
     :template: classtemplate.rst
 
+    BackboneLambdaFinetuningCallback
+    BaseFinetuningCallback
     Callback
     EarlyStopping
     GPUStatsMonitor
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index 37e0c50c63334..f9334358a04c8 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -78,6 +78,7 @@ def _recursive_freeze(module: Module,
     def filter_params(module: Module,
                       train_bn: bool = True) -> Generator:
         """Yields the trainable parameters of a given module.
+
         Args:
             module: A given module
             train_bn: If True, leave the BatchNorm layers in training mode
@@ -98,6 +99,7 @@ def filter_params(module: Module,
     @staticmethod
     def freeze(module: Module, train_bn: bool = True) -> None:
         """Freezes the layers up to index n (if n is not None).
+
         Args:
             module: The module to freeze (at least partially)
             train_bn: If True, leave the BatchNorm layers in training mode
@@ -148,6 +150,7 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
     Finetunne a backbone model based on a learning rate user-defined scheduling.
     When the backbone learning rate reaches the current model learning rate
     and ``should_align`` is set to True, it will align with it for the rest of the training.
+
     Args:
         unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfreezed.
         lambda_func: Scheduling function for increasing backbone learning rate.
@@ -165,7 +168,9 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback):
             reaches it.
         verbose: Display current learning rate for model and backbone
         round: Precision for displaying learning rate
+
     Example::
+
         >>> from pytorch_lightning import Trainer
         >>> from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback
         >>> multiplicative = lambda epoch: 1.5
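
For reference, a minimal usage sketch of the ``BackboneLambdaFinetuningCallback`` documented above, based on the arguments named in its docstring; the epoch value and the multiplicative schedule are illustrative assumptions, not taken from the diff::

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback

    # Multiply the backbone learning rate by 1.5 each epoch once it is unfrozen.
    multiplicative = lambda epoch: 1.5

    # Hypothetical values: keep the backbone frozen for the first 10 epochs,
    # then unfreeze it and scale its learning rate with ``multiplicative``.
    backbone_finetuning = BackboneLambdaFinetuningCallback(
        unfreeze_backbone_at_epoch=10,
        lambda_func=multiplicative,
    )

    trainer = Trainer(callbacks=[backbone_finetuning])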