diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 0efc8e5b16658..3e3b736ecf523 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -144,6 +144,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))
 
+- Raised a warning instead of forcing `sync_dist=True` on epoch end ([#13364](https://github.com/Lightning-AI/lightning/pull/13364))
+
+
 - Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832)
 
 
diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 27cb3cb0323b2..9eb88fda4891e 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -24,11 +24,13 @@
 from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
 from pytorch_lightning.utilities.data import extract_batch_size
+from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.metrics import metrics_to_scalars
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.warnings import WarningCache
+from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache
 
 _IN_METRIC = Union[Metric, Tensor]  # Do not include scalars as they were converted to tensors
 _OUT_METRIC = Union[Tensor, Dict[str, Tensor]]
@@ -522,12 +524,26 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
             cache = result_metric._forward_cache
         elif not on_step and result_metric.meta.on_epoch:
             if result_metric._computed is None:
-                # always reduce on epoch end
                 should = result_metric.meta.sync.should
-                result_metric.meta.sync.should = True
+                if not result_metric.meta.sync.should and distributed_available():
+                    # ensure the sync happens for fault-tolerant training: on a failure, the metrics are synced and
+                    # saved to the checkpoint, so after a restart rank 0 restores the accumulated values from the
+                    # previous run while the other ranks start from 0. They must therefore be synced again during
+                    # further training to produce correct results.
+                    if _fault_tolerant_training():
+                        result_metric.meta.sync.should = True
+                    else:
+                        warning_cache.warn(
+                            f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
+                            " when logging on epoch level in distributed setting to accumulate the metric across"
+                            " devices.",
+                            category=PossibleUserWarning,
+                        )
                 result_metric.compute()
                 result_metric.meta.sync.should = should
+
             cache = result_metric._computed
+
         if cache is not None:
             if not isinstance(cache, Tensor):
                 raise ValueError(
@@ -536,6 +552,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
                 )
             if not result_metric.meta.enable_graph:
                 return cache.detach()
+
         return cache
 
     def valid_items(self) -> Generator:
diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py
index 12247e27e8c9d..cb8a51c5bf9ba 100644
--- a/tests/tests_pytorch/core/test_metric_result_integration.py
+++ b/tests/tests_pytorch/core/test_metric_result_integration.py
@@ -34,7 +34,9 @@
     _ResultMetric,
     _Sync,
 )
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.utils import no_warning_call
 
 
 class DummyMetric(Metric):
@@ -456,6 +458,8 @@ def on_train_epoch_end(self) -> None:
         "limit_val_batches": 0,
         "accelerator": accelerator,
         "devices": devices,
+        "enable_progress_bar": False,
+        "enable_model_summary": False,
     }
     trainer_kwargs.update(kwargs)
     trainer = Trainer(**trainer_kwargs)
@@ -471,7 +475,7 @@ def on_train_epoch_end(self) -> None:
     )
 
     ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
 
-    trainer = Trainer(**trainer_kwargs, enable_progress_bar=False, enable_model_summary=False)
+    trainer = Trainer(**trainer_kwargs)
     trainer.fit(model, ckpt_path=ckpt_path)
     assert model.has_validated_sum
@@ -659,3 +663,22 @@ def on_train_start(self):
     )
     with pytest.raises(ValueError, match=r"compute\(\)` return of.*foo' must be a tensor"):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("distributed_env", [True, False])
+def test_logger_sync_dist(distributed_env):
+    # self.log('bar', 7, ..., sync_dist=False)
+    meta = _Metadata("foo", "bar")
+    meta.sync = _Sync(_should=False)
+    result_metric = _ResultMetric(metadata=meta, is_tensor=True)
+    result_metric.update(torch.tensor(7.0), 10)
+
+    warning_ctx = pytest.warns if distributed_env else no_warning_call
+
+    with mock.patch(
+        "pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available",
+        return_value=distributed_env,
+    ):
+        with warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"):
+            value = _ResultCollection._get_cache(result_metric, on_step=False)
+        assert value == 7.0
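Note (not part of the diff): a minimal sketch of the `self.log(..., sync_dist=True)` call that the new `PossibleUserWarning` recommends, assuming a plain `LightningModule`. The model, layer sizes, and loss below are illustrative placeholders; only the `sync_dist=True` argument on an epoch-level log call is the point being shown.

```python
# Hypothetical example module, not taken from this PR.
import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Logged on epoch level: reduce the value across ranks explicitly with
        # `sync_dist=True` instead of relying on the old implicit forcing of
        # `sync.should = True` that this PR removes outside fault-tolerant mode.
        self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```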