From 92621063397ad3c497acd5992899c4b6e70c0be8 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 14 Oct 2021 20:51:34 +0200 Subject: [PATCH] Fixed use of LightningCLI in computer_vision_fine_tuning.py example --- CHANGELOG.md | 3 +++ .../domain_templates/computer_vision_fine_tuning.py | 10 ++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd726aa448053..0aff62dab5f09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -533,6 +533,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Reset `val_dataloader` in `tuner/batch_size_scaling` ([#9857](https://github.com/PyTorchLightning/pytorch-lightning/pull/9857)) +- Fixed use of `LightningCLI` in computer_vision_fine_tuning.py example ([#9934](https://github.com/PyTorchLightning/pytorch-lightning/pull/9934)) + + ## [1.4.9] - 2021-09-30 - Fixed `lr_find` to generate same results on multiple calls ([#9704](https://github.com/PyTorchLightning/pytorch-lightning/pull/9704)) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 631832bd5de4b..c507a6f0e9588 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -34,6 +34,9 @@ Note: See: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html + +To run: + python computer_vision_fine_tuning.py fit """ import logging @@ -265,7 +268,7 @@ def configure_optimizers(self): class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): - parser.add_class_arguments(MilestonesFinetuning, "finetuning") + parser.add_lightning_class_args(MilestonesFinetuning, "finetuning") parser.link_arguments("data.batch_size", "model.batch_size") parser.link_arguments("finetuning.milestones", "model.milestones") parser.link_arguments("finetuning.train_bn", "model.train_bn") @@ -277,11 +280,6 @@ def add_arguments_to_parser(self, parser): } ) - def instantiate_trainer(self, *args): - finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning")) - self.trainer_defaults["callbacks"] = [finetuning_callback] - return super().instantiate_trainer(*args) - def cli_main(): MyLightningCLI(TransferLearningModel, CatDogImageDataModule, seed_everything_default=1234)