From d7182ad21fe59f3eac2733402d7dacc90b4e7e04 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 18 Jan 2021 17:43:49 +0900 Subject: [PATCH 1/2] Add and fix the docs of BackboneLambdaFinetuningCallback --- docs/source/callbacks.rst | 2 ++ pytorch_lightning/callbacks/finetuning.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index e955ad89fa998..d7039ce7db20e 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -94,6 +94,8 @@ Lightning has a few built-in callbacks. :nosignatures: :template: classtemplate.rst + BaseFinetuningCallback + BackboneLambdaFinetuningCallback Callback EarlyStopping GPUStatsMonitor diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 37e0c50c63334..f9334358a04c8 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -78,6 +78,7 @@ def _recursive_freeze(module: Module, def filter_params(module: Module, train_bn: bool = True) -> Generator: """Yields the trainable parameters of a given module. + Args: module: A given module train_bn: If True, leave the BatchNorm layers in training mode @@ -98,6 +99,7 @@ def filter_params(module: Module, @staticmethod def freeze(module: Module, train_bn: bool = True) -> None: """Freezes the layers up to index n (if n is not None). + Args: module: The module to freeze (at least partially) train_bn: If True, leave the BatchNorm layers in training mode @@ -148,6 +150,7 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback): Finetunne a backbone model based on a learning rate user-defined scheduling. When the backbone learning rate reaches the current model learning rate and ``should_align`` is set to True, it will align with it for the rest of the training. + Args: unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfreezed. lambda_func: Scheduling function for increasing backbone learning rate. @@ -165,7 +168,9 @@ class BackboneLambdaFinetuningCallback(BaseFinetuningCallback): reaches it. verbose: Display current learning rate for model and backbone round: Precision for displaying learning rate + Example:: + >>> from pytorch_lightning import Trainer >>> from pytorch_lightning.callbacks import BackboneLambdaFinetuningCallback >>> multiplicative = lambda epoch: 1.5 From 849fa71630c62ea2925657cc51187728052c0fe7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 18 Jan 2021 18:10:39 +0900 Subject: [PATCH 2/2] Apply suggestions from code review --- docs/source/callbacks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index d7039ce7db20e..2588cb0a65de1 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -94,8 +94,8 @@ Lightning has a few built-in callbacks. :nosignatures: :template: classtemplate.rst - BaseFinetuningCallback BackboneLambdaFinetuningCallback + BaseFinetuningCallback Callback EarlyStopping GPUStatsMonitor