
Commit 24c3a3f

Add possibility for custom naming when using multiple dataloaders (#6274)
1 parent 38274b9 commit 24c3a3f

File tree

3 files changed: +56 / -3 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
 
+- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
+
+
 ### Changed
 
 - Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
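
As a rough illustration of the behaviour this entry describes (not code from the commit; the model, dataset sizes and metric names below are invented for the example), a LightningModule with two validation dataloaders could log either with the default index suffix or with custom per-dataloader names:

# Hypothetical usage sketch for the new `add_dataloader_idx` argument.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class MultiValModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx, dataloader_idx):
        loss = self.layer(batch[0]).sum()
        # Default: Lightning appends the dataloader index, producing
        # "val_loss/dataloader_idx_0" and "val_loss/dataloader_idx_1".
        self.log("val_loss", loss)
        # Opt out and choose a unique name per dataloader instead.
        self.log(f"val_loss_dl{dataloader_idx}", loss, add_dataloader_idx=False)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def val_dataloader(self):
        ds = TensorDataset(torch.randn(64, 32))
        return [DataLoader(ds, batch_size=8), DataLoader(ds, batch_size=8)]


if __name__ == "__main__":
    trainer = Trainer(max_steps=5, num_sanity_val_steps=0)
    trainer.fit(MultiValModel())
    print(trainer.logged_metrics)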

pytorch_lightning/core/lightning.py

Lines changed: 15 additions & 3 deletions
@@ -226,6 +226,7 @@ def log(
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
         sync_dist_group: Optional[Any] = None,
+        add_dataloader_idx: bool = True,
     ):
         """
         Log a key, value
@@ -260,7 +261,10 @@ def log(
             enable_graph: if True, will not auto detach the graph
             sync_dist: if True, reduces the metric across GPUs/TPUs
             sync_dist_op: the op to sync across GPUs/TPUs
-            sync_dist_group: the ddp group
+            sync_dist_group: the ddp group to sync across
+            add_dataloader_idx: if True, appends the index of the current dataloader to
+                the name (when using multiple). If False, user needs to give unique names for
+                each dataloader to not mix values
         """
         if self._results is not None:
             # in any epoch end can't log step metrics (only epoch metric)
@@ -292,6 +296,9 @@ def log(
 
             training_type_plugin = self.trainer.training_type_plugin
 
+            # Determine if dataloader index should be added
+            dataloader_idx = self._current_dataloader_idx if add_dataloader_idx else None
+
             self._results.log(
                 name,
                 value,
@@ -307,7 +314,7 @@ def log(
                 sync_dist_op,
                 sync_dist_group,
                 training_type_plugin.reduce,
-                self._current_dataloader_idx,
+                dataloader_idx,
                 self.device,
             )
 
@@ -325,6 +332,7 @@ def log_dict(
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
         sync_dist_group: Optional[Any] = None,
+        add_dataloader_idx: bool = True,
     ):
         """
         Log a dictonary of values at once
@@ -346,7 +354,10 @@ def log_dict(
             enable_graph: if True, will not auto detach the graph
             sync_dist: if True, reduces the metric across GPUs/TPUs
             sync_dist_op: the op to sync across GPUs/TPUs
-            sync_dist_group: the ddp group:
+            sync_dist_group: the ddp group sync across
+            add_dataloader_idx: if True, appends the index of the current dataloader to
+                the name (when using multiple). If False, user needs to give unique names for
+                each dataloader to not mix values
         """
         for k, v in dictionary.items():
             self.log(
@@ -363,6 +374,7 @@ def log_dict(
                 sync_dist_op=sync_dist_op,
                 tbptt_pad_token=tbptt_pad_token,
                 tbptt_reduce_fx=tbptt_reduce_fx,
+                add_dataloader_idx=add_dataloader_idx
             )
 
     def write_prediction(
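
The docstring above describes the suffixing behaviour in prose; as a toy restatement (an assumption about the resulting key format that matches the assertions in the test below, not the library's actual code path):

from typing import Optional


def suffixed_key(name: str, dataloader_idx: Optional[int]) -> str:
    """Toy helper mimicking the key naming the test below asserts on."""
    # None corresponds to add_dataloader_idx=False, i.e. the name is kept as-is.
    if dataloader_idx is None:
        return name
    return f"{name}/dataloader_idx_{dataloader_idx}"


assert suffixed_key("val_loss", 0) == "val_loss/dataloader_idx_0"
assert suffixed_key("val_loss", None) == "val_loss"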

tests/trainer/logging_/test_logger_connector.py

Lines changed: 38 additions & 0 deletions
@@ -471,3 +471,41 @@ def training_step(self, *args, **kwargs):
     )
     with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("add_dataloader_idx", [False, True])
+def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
+    """ test that auto_add_dataloader_idx argument works """
+
+    class TestModel(BoringModel):
+        def val_dataloader(self):
+            dl = super().val_dataloader()
+            return [dl, dl]
+
+        def validation_step(self, *args, **kwargs):
+            output = super().validation_step(*args[:-1], **kwargs)
+            if add_dataloader_idx:
+                name = "val_loss"
+            else:
+                name = f"val_loss_custom_naming_{args[-1]}"
+
+            self.log(name, output["x"], add_dataloader_idx=add_dataloader_idx)
+            return output
+
+    model = TestModel()
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=5
+    )
+    trainer.fit(model)
+    logged = trainer.logged_metrics
+
+    # Check that the correct keys exist
+    if add_dataloader_idx:
+        assert 'val_loss/dataloader_idx_0' in logged
+        assert 'val_loss/dataloader_idx_1' in logged
+    else:
+        assert 'val_loss_custom_naming_0' in logged
+        assert 'val_loss_custom_naming_1' in logged
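
The new test is parametrized over both values of add_dataloader_idx. It can be selected on its own with a standard pytest keyword filter from the repository root (path taken from the diff header above), e.g. python -m pytest tests/trainer/logging_/test_logger_connector.py -k test_auto_add_dataloader_idx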
