
Commit c8e22b4

rohitgr7 authored
Avoid raising the sampler warning if num_replicas=1 (#14097)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: otaj <[email protected]>
1 parent 807f9d8 commit c8e22b4

File tree

3 files changed: 15 additions, 7 deletions

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -89,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
  - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
 
 
+ - Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))
+
+
  - Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
 
 

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 7 additions & 3 deletions
@@ -298,10 +298,14 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
              # update docs too once this is resolved
              trainer_fn = self.trainer.state.fn
-             if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
+             if (
+                 isinstance(sampler, DistributedSampler)
+                 and sampler.num_replicas > 1
+                 and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING)
+             ):
                  rank_zero_warn(
-                     f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
-                     " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
+                     f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`, it is"
+                     " recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated"
                      " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                      " some samples to make sure all devices have same batch size in case of uneven inputs.",
                      category=PossibleUserWarning,
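
For illustration, here is a minimal standalone sketch of the behaviour this change targets. It is not Lightning's internal code path: `check_sampler_warning` and the plain `warnings.warn` call are stand-ins for `_resolve_sampler` and `rank_zero_warn`. The point is that a `DistributedSampler` constructed with `num_replicas=1` no longer triggers the recommendation, while a multi-replica sampler still does.

```python
# Minimal sketch of the new guard, not Lightning's actual code path:
# the warning should only fire when the sampler really shards data
# across more than one replica.
import warnings

import torch
from torch.utils.data import DistributedSampler, TensorDataset


def check_sampler_warning(sampler, trainer_fn: str = "validate") -> None:
    # Mirrors the updated condition: `num_replicas > 1` is now required
    # before the warning is emitted.
    if isinstance(sampler, DistributedSampler) and sampler.num_replicas > 1:
        warnings.warn(
            f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn}()`, it is"
            " recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets"
            " evaluated exactly once."
        )


dataset = TensorDataset(torch.arange(8, dtype=torch.float32))

# Single replica (e.g. `Trainer(devices=1)`): no warning after this change.
check_sampler_warning(DistributedSampler(dataset, num_replicas=1, rank=0))

# More than one replica: the warning is still raised.
check_sampler_warning(DistributedSampler(dataset, num_replicas=2, rank=0))
```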

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 5 additions & 4 deletions
@@ -526,19 +526,20 @@ def test_invalid_hook_passed_in_datahook_selector():
         dh_selector.get_instance("setup")
 
 
-def test_eval_distributed_sampler_warning(tmpdir):
+@pytest.mark.parametrize("devices, warn_context", [(1, no_warning_call), (2, pytest.warns)])
+def test_eval_distributed_sampler_warning(devices, warn_context):
     """Test that a warning is raised when `DistributedSampler` is used with evaluation."""
 
     model = BoringModel()
-    trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
+    trainer = Trainer(strategy="ddp", devices=devices, accelerator="cpu")
     trainer._data_connector.attach_data(model)
 
     trainer.state.fn = TrainerFn.VALIDATING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_val_dataloader(model)
 
     trainer.state.fn = TrainerFn.TESTING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_test_dataloader(model)
 
 
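The test now parametrizes the expected-warning context manager itself, so one test body covers both the "warns" (two devices) and the "does not warn" (one device) cases. Below is a hedged, self-contained sketch of that pattern; the `no_warning_call` helper here is a simplified stand-in for the one used in Lightning's test suite, not its actual implementation.

```python
# Sketch of parametrizing the expected-warning context manager itself.
# `no_warning_call` below is a simplified stand-in for the helper used in
# Lightning's tests; `pytest.warns` is the real pytest API.
import contextlib
import warnings

import pytest


@contextlib.contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    # Fail if a matching warning is emitted inside the block.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        if issubclass(w.category, expected_warning) and (match is None or match in str(w.message)):
            raise AssertionError(f"Unexpected warning: {w.message}")


@pytest.mark.parametrize("num_replicas, warn_context", [(1, no_warning_call), (2, pytest.warns)])
def test_warning_only_for_multiple_replicas(num_replicas, warn_context):
    # With num_replicas=1 no warning is emitted and `no_warning_call` passes;
    # with num_replicas=2 the warning is emitted and `pytest.warns` passes.
    with warn_context(UserWarning, match="multi-device"):
        if num_replicas > 1:
            warnings.warn("multi-device settings use `DistributedSampler`", UserWarning)
```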
