From 0bc92dfafb8dd3bf08484fdd162c9f0b95d9f401 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Mon, 26 Oct 2020 06:14:48 -0700 Subject: [PATCH 01/11] passing batch outputs to on_train_batch_end --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d32f47dbbd485..47208856bdfd8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -557,7 +557,7 @@ def run_training_epoch(self): # hook # TODO: add outputs to batches - self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx) + self.on_train_batch_end(epoch_output, batch_output.training_step_output_for_epoch_end, batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS From c738b14fbd00b0f9dcf38ff04d2a0cd703083de8 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Mon, 26 Oct 2020 06:21:10 -0700 Subject: [PATCH 02/11] styling --- pytorch_lightning/trainer/training_loop.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 47208856bdfd8..2686ca1846d1a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -557,7 +557,10 @@ def run_training_epoch(self): # hook # TODO: add outputs to batches - self.on_train_batch_end(epoch_output, batch_output.training_step_output_for_epoch_end, batch, batch_idx, dataloader_idx) + self.on_train_batch_end( + epoch_output, + batch_output.training_step_output_for_epoch_end, + batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS From 725b5e310b9d96d68c51616f9b3bfc79d5483bc2 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Wed, 11 Nov 2020 07:13:28 -0800 Subject: [PATCH 03/11] updating epoch end logic --- pytorch_lightning/trainer/training_loop.py | 41 +++++++++++----------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 0bcf12fb213d5..20b9bfd80f8a1 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -250,13 +250,13 @@ def on_train_epoch_start(self, epoch): self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") - def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx): # hook self.trainer.call_hook('on_batch_end') - self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx) + self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx) # figure out what to track for epoch end - self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs) + self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs) # reset batch logger internals self.trainer.logger_connector.on_train_batch_end() @@ -268,12 +268,22 @@ def reset_train_val_dataloaders(self, model): if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch: self.trainer.reset_val_dataloader(model) - def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs): + def track_epoch_end_reduce_metrics(self, 
epoch_output, batch_end_outputs): + # track the outputs to reduce at the end of the epoch - for opt_idx, opt_outputs in enumerate(epoch_end_outputs): + for opt_idx, opt_outputs in enumerate(batch_end_outputs): + sample_output = opt_outputs[-1] + + # decide if we need to reduce at the end of the epoch automatically + auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end + if not(is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result): + continue + # with 1 step (no tbptt) don't use a sequence at epoch end if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result): opt_outputs = opt_outputs[0] + epoch_output[opt_idx].append(opt_outputs) def get_optimizers_iterable(self): @@ -542,20 +552,14 @@ def run_training_epoch(self): if batch_output.signal == -1: break - # only track outputs when user implements training_epoch_end - # otherwise we will build up unnecessary memory - epoch_end_outputs = self.process_train_step_outputs( + batch_end_outputs = self.process_train_step_outputs( batch_output.training_step_output_for_epoch_end, self.early_stopping_accumulator, self.checkpoint_accumulator, ) - # hook # TODO: add outputs to batches - self.on_train_batch_end( - epoch_output, - batch_output.training_step_output_for_epoch_end, - batch, batch_idx, dataloader_idx) + self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS @@ -888,7 +892,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu # the training step outputs a list per optimizer. 
The list contains the outputs at each time step # when no TBPTT is used, then the list has 1 item per batch # when TBPTT IS used, then the list has n items (1 per time step) - epoch_end_outputs = [] + batch_end_outputs = [] for optimizer_idx_outputs in all_train_step_outputs: # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer if len(optimizer_idx_outputs) == 0: @@ -903,14 +907,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu if isinstance(sample_output, dict) and "checkpoint_on" in sample_output: checkpoint_accumulator.accumulate(sample_output["checkpoint_on"]) - # decide if we need to reduce at the end of the epoch automatically - auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - - # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result: - epoch_end_outputs.append(optimizer_idx_outputs) + batch_end_outputs.append(optimizer_idx_outputs) - return epoch_end_outputs + return batch_end_outputs def prepare_optimizers(self): # in manual optimization we loop over all optimizers at once From c2fc70f754b9cbbf1f6df11ec33402f4a329d189 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Wed, 18 Nov 2020 14:58:42 -0800 Subject: [PATCH 04/11] also condition on on_train_epoch_end hooks --- pytorch_lightning/trainer/training_loop.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 20b9bfd80f8a1..46fdca6f3d1e8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -277,7 +277,9 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if not(is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result): + if not(is_overridden("training_epoch_end", model=self.trainer.get_model()) + or not(is_overridden("on_train_epoch_end", model=self.trainer.get_model())) + or auto_reduce_tng_result): continue # with 1 step (no tbptt) don't use a sequence at epoch end From 79b42752a8a1e9b3a0df28c2c28e0e974444445c Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Fri, 20 Nov 2020 12:16:48 -0800 Subject: [PATCH 05/11] more readable --- pytorch_lightning/trainer/training_loop.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 46fdca6f3d1e8..865a66384ffa7 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -276,10 +276,11 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + hook_overridden = (is_overridden("training_epoch_end", model=self.trainer.get_model()) + or is_overridden("on_train_epoch_end", model=self.trainer.get_model())) + # only track when a) it needs to be autoreduced OR b) the user wants 
to manually reduce on epoch end - if not(is_overridden("training_epoch_end", model=self.trainer.get_model()) - or not(is_overridden("on_train_epoch_end", model=self.trainer.get_model())) - or auto_reduce_tng_result): + if not(hook_not_overridden or auto_reduce_tng_result): continue # with 1 step (no tbptt) don't use a sequence at epoch end From 09f3c5d098224ae3d638a7c9db53166f22765c45 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Fri, 20 Nov 2020 12:47:38 -0800 Subject: [PATCH 06/11] pep8 --- pytorch_lightning/trainer/training_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 865a66384ffa7..661bc693bedf3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -276,8 +276,9 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - hook_overridden = (is_overridden("training_epoch_end", model=self.trainer.get_model()) - or is_overridden("on_train_epoch_end", model=self.trainer.get_model())) + hook_overridden = ( + is_overridden("training_epoch_end", model=self.trainer.get_model()) or + is_overridden("on_train_epoch_end", model=self.trainer.get_model())) # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if not(hook_not_overridden or auto_reduce_tng_result): From add4a7eaa607a1ec65585b6245d1fbf699e33cd5 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Fri, 20 Nov 2020 13:39:08 -0800 Subject: [PATCH 07/11] pep8 --- pytorch_lightning/trainer/training_loop.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 46fdca6f3d1e8..f92de79c5fb56 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -276,10 +276,12 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end + hook_overridden = ( + is_overridden("training_epoch_end", model=self.trainer.get_model()) or + is_overridden("on_train_epoch_end", model=self.trainer.get_model())) + # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end - if not(is_overridden("training_epoch_end", model=self.trainer.get_model()) - or not(is_overridden("on_train_epoch_end", model=self.trainer.get_model())) - or auto_reduce_tng_result): + if not(hook_overridden or auto_reduce_tng_result): continue # with 1 step (no tbptt) don't use a sequence at epoch end From 907991b5c342d6e01314206982500082738d039a Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Mon, 30 Nov 2020 12:59:33 -0800 Subject: [PATCH 08/11] readability suggestion accepted Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/training_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3c3696824e182..c1c7ab5a6f5a3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -278,7 +278,8 @@ def 
track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end hook_overridden = ( is_overridden("training_epoch_end", model=self.trainer.get_model()) or - is_overridden("on_train_epoch_end", model=self.trainer.get_model())) + is_overridden("on_train_epoch_end", model=self.trainer.get_model()) + ) # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if not(hook_overridden or auto_reduce_tng_result): From fb24a4d06bb86e177623cf5f3f4c7366ff892af2 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Thu, 21 Jan 2021 18:56:10 -0800 Subject: [PATCH 09/11] adding test_training_epoch_end_metrics_collection_on_override test --- tests/models/test_hooks.py | 46 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index f3af5b745a380..a8c76b4f1d659 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -17,8 +17,10 @@ import torch from unittest.mock import MagicMock + from pytorch_lightning import Trainer from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator +import pytorch_lightning as pl from tests.base import EvalModelTemplate, BoringModel @@ -88,6 +90,50 @@ def training_epoch_end(self, outputs): for i in range(num_epochs): assert metrics[f'epoch_metric_{i}'] == i +def test_training_epoch_end_metrics_collection_on_override(tmpdir): + """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """ + num_epochs = 1 + + class LoggingCallback(pl.Callback): + def on_train_epoch_end(self, trainer, pl_module): + self.len_outputs = 0 + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.len_outputs = len(outputs[0]) + + class OverriddenModel(EvalModelTemplate): + def on_train_epoch_start(self): + self.num_train_batches = 0 + def training_epoch_end(self, outputs): #Overridden + return + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + class NotOverriddenModel(EvalModelTemplate): + def on_train_epoch_start(self): + self.num_train_batches = 0 + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + + overridden_model = OverriddenModel() + not_overridden_model = NotOverriddenModel() + + callback = LoggingCallback() + trainer = Trainer( + max_epochs=num_epochs, + default_root_dir=tmpdir, + overfit_batches=2, + callbacks = [callback], + ) + + result = trainer.fit(overridden_model) + assert callback.len_outputs == overridden_model.num_train_batches + # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden + + result = trainer.fit(not_overridden_model) + assert callback.len_outputs == 0 + # outputs from on_train_batch_end should be empty + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_transfer_batch_hook(): From fa586aedbafb3022e1610669bd19e04c95b3e014 Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Thu, 21 Jan 2021 19:05:35 -0800 Subject: [PATCH 10/11] fix formatting --- tests/models/test_hooks.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 88f4f9c993646..04b7c9e628197 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -91,31 
+91,38 @@ def training_epoch_end(self, outputs): for i in range(num_epochs): assert metrics[f'epoch_metric_{i}'] == i + def test_training_epoch_end_metrics_collection_on_override(tmpdir): """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """ num_epochs = 1 class LoggingCallback(pl.Callback): + def on_train_epoch_end(self, trainer, pl_module): self.len_outputs = 0 + def on_train_epoch_end(self, trainer, pl_module, outputs): self.len_outputs = len(outputs[0]) class OverriddenModel(EvalModelTemplate): + def on_train_epoch_start(self): self.num_train_batches = 0 - def training_epoch_end(self, outputs): #Overridden + + def training_epoch_end(self, outputs): #Overridden return + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.num_train_batches += 1 class NotOverriddenModel(EvalModelTemplate): + def on_train_epoch_start(self): self.num_train_batches = 0 + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.num_train_batches += 1 - overridden_model = OverriddenModel() not_overridden_model = NotOverriddenModel() @@ -124,7 +131,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): max_epochs=num_epochs, default_root_dir=tmpdir, overfit_batches=2, - callbacks = [callback], + callbacks=[callback], ) result = trainer.fit(overridden_model) From 00846db6cf85f31aa10e6063ec32ad7fc0f604fc Mon Sep 17 00:00:00 2001 From: Swetha Mandava Date: Thu, 21 Jan 2021 19:11:45 -0800 Subject: [PATCH 11/11] fix formatting --- tests/models/test_hooks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 04b7c9e628197..62d17515119cd 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -109,7 +109,8 @@ class OverriddenModel(EvalModelTemplate): def on_train_epoch_start(self): self.num_train_batches = 0 - def training_epoch_end(self, outputs): #Overridden + def training_epoch_end(self, outputs): # Overridden + pass return def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
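
For reference, a minimal user-side sketch (not taken from the patches above; the callback name and print are illustrative only) of how the batch outputs forwarded by this series can be consumed, assuming the callback hook signature exercised by the new test, on_train_epoch_end(self, trainer, pl_module, outputs):

    import pytorch_lightning as pl


    class EpochOutputInspector(pl.Callback):
        """Illustrative callback that inspects the batch outputs collected for the epoch."""

        def on_train_epoch_end(self, trainer, pl_module, outputs):
            # `outputs` holds one list per optimizer; each inner list contains the
            # results returned by `training_step` for the batches of this epoch.
            num_tracked = len(outputs[0]) if outputs else 0
            print(f"collected {num_tracked} batch outputs this epoch")

As with the test's LoggingCallback, such a callback would be registered via Trainer(callbacks=[EpochOutputInspector()]). Note that outputs[0] is only non-empty when the LightningModule overrides training_epoch_end or on_train_epoch_end, per the condition added to track_epoch_end_reduce_metrics in this series.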