diff --git a/CHANGELOG.md b/CHANGELOG.md
index 670392e60ab96..89a7baff83e08 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
+- Added `add_dataloader_idx` arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
+
+
 ### Changed
 
 - Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 52bcc213692ac..743e5b5425c1e 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -225,6 +225,7 @@ def log(
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
         sync_dist_group: Optional[Any] = None,
+        add_dataloader_idx: bool = True,
     ):
         """
         Log a key, value
@@ -259,7 +260,10 @@ def log(
             enable_graph: if True, will not auto detach the graph
             sync_dist: if True, reduces the metric across GPUs/TPUs
             sync_dist_op: the op to sync across GPUs/TPUs
-            sync_dist_group: the ddp group
+            sync_dist_group: the ddp group to sync across
+            add_dataloader_idx: if True, appends the index of the current dataloader to
+                the name (when using multiple dataloaders). If False, the user needs to give
+                unique names for each dataloader so that the values do not get mixed
         """
         if self._results is not None:
             # in any epoch end can't log step metrics (only epoch metric)
@@ -291,6 +295,9 @@ def log(
 
             training_type_plugin = self.trainer.training_type_plugin
 
+            # determine whether the dataloader index should be appended to the metric name
+            dataloader_idx = self._current_dataloader_idx if add_dataloader_idx else None
+
             self._results.log(
                 name,
                 value,
@@ -306,7 +313,7 @@ def log(
                 sync_dist_op,
                 sync_dist_group,
                 training_type_plugin.reduce,
-                self._current_dataloader_idx,
+                dataloader_idx,
                 self.device,
             )
 
@@ -324,6 +331,7 @@ def log_dict(
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
         sync_dist_group: Optional[Any] = None,
+        add_dataloader_idx: bool = True,
     ):
         """
         Log a dictonary of values at once
@@ -345,7 +353,10 @@ def log_dict(
             enable_graph: if True, will not auto detach the graph
             sync_dist: if True, reduces the metric across GPUs/TPUs
             sync_dist_op: the op to sync across GPUs/TPUs
-            sync_dist_group: the ddp group:
+            sync_dist_group: the ddp group to sync across
+            add_dataloader_idx: if True, appends the index of the current dataloader to
+                the name (when using multiple dataloaders). If False, the user needs to give
+                unique names for each dataloader so that the values do not get mixed
         """
         for k, v in dictionary.items():
             self.log(
@@ -362,6 +373,7 @@ def log_dict(
                 sync_dist_op=sync_dist_op,
                 tbptt_pad_token=tbptt_pad_token,
                 tbptt_reduce_fx=tbptt_reduce_fx,
+                add_dataloader_idx=add_dataloader_idx,
             )
 
     def write_prediction(
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 142646fec6cdb..5aa7175e3ef95 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -471,3 +471,41 @@ def training_step(self, *args, **kwargs):
     )
     with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("add_dataloader_idx", [False, True])
+def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
+    """ test that the `add_dataloader_idx` argument works """
+
+    class TestModel(BoringModel):
+        def val_dataloader(self):
+            dl = super().val_dataloader()
+            return [dl, dl]
+
+        def validation_step(self, *args, **kwargs):
+            output = super().validation_step(*args[:-1], **kwargs)
+            if add_dataloader_idx:
+                name = "val_loss"
+            else:
+                name = f"val_loss_custom_naming_{args[-1]}"
+
+            self.log(name, output["x"], add_dataloader_idx=add_dataloader_idx)
+            return output
+
+    model = TestModel()
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=5,
+    )
+    trainer.fit(model)
+    logged = trainer.logged_metrics
+
+    # check that the expected keys exist
+    if add_dataloader_idx:
+        assert 'val_loss/dataloader_idx_0' in logged
+        assert 'val_loss/dataloader_idx_1' in logged
+    else:
+        assert 'val_loss_custom_naming_0' in logged
+        assert 'val_loss_custom_naming_1' in logged
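For context, a minimal usage sketch of the new flag, not part of the diff: the `LitModel` class, its random datasets, and the custom metric names are illustrative assumptions; only the `add_dataloader_idx` argument and the `/dataloader_idx_N` suffix come from the change above.

# sketch: two validation dataloaders, one metric logged with the default
# suffixing and one with manually unique names via add_dataloader_idx=False
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx, dataloader_idx):
        loss = self.layer(batch[0]).sum()
        # default behaviour: logged as "val_loss/dataloader_idx_0" and "val_loss/dataloader_idx_1"
        self.log("val_loss", loss)
        # opt out of the suffix; the user is responsible for keeping names unique per dataloader
        self.log(f"val_loss_split_{dataloader_idx}", loss, add_dataloader_idx=False)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def val_dataloader(self):
        ds = TensorDataset(torch.randn(64, 32))
        return [DataLoader(ds, batch_size=8), DataLoader(ds, batch_size=8)]


if __name__ == "__main__":
    Trainer(max_steps=2).fit(LitModel())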