3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -144,6 +144,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled using any Sampler in distributed environment in Lite ([#13646](https://github.com/PyTorchLightning/pytorch-lightning/pull/13646))


- Raised a warning instead of forcing `sync_dist=True` on epoch end ([#13364](https://github.com/Lightning-AI/lightning/pull/13364))


- Updated `val_check_interval`(int) to consider total train batches processed instead of `_batches_that_stepped` for validation check during training ([#12832](https://github.com/Lightning-AI/lightning/pull/12832))


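The `sync_dist` entry above maps onto the standard `LightningModule.log` call. A minimal, illustrative sketch of the usage the new warning recommends (the module and metric names below are placeholders, not part of this PR):

```python
import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # placeholder for a real loss computation
        # When logging on epoch level in a distributed run, pass sync_dist=True so the
        # value is reduced across devices; otherwise the warning added in this PR may be emitted.
        self.log("val_loss", loss, on_epoch=True, sync_dist=True)
        return loss
```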
src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -24,11 +24,13 @@
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.warnings import WarningCache
+from pytorch_lightning.utilities.warnings import PossibleUserWarning, WarningCache

_IN_METRIC = Union[Metric, Tensor] # Do not include scalars as they were converted to tensors
_OUT_METRIC = Union[Tensor, Dict[str, Tensor]]
@@ -522,12 +524,26 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
            cache = result_metric._forward_cache
        elif not on_step and result_metric.meta.on_epoch:
            if result_metric._computed is None:
-                # always reduce on epoch end
                should = result_metric.meta.sync.should
-                result_metric.meta.sync.should = True
                if not result_metric.meta.sync.should and distributed_available():
                    # ensure sync happens for FT since during a failure, the metrics are synced and saved to the
                    # checkpoint, so during restart, metrics on rank 0 are from the accumulated ones from the previous
                    # run, and on other ranks, they are 0. So we need to make sure they are synced in further training
                    # to ensure correct calculation.
                    if _fault_tolerant_training():
                        result_metric.meta.sync.should = True
                    else:
                        warning_cache.warn(
                            f"It is recommended to use `self.log({result_metric.meta.name!r}, ..., sync_dist=True)`"
                            " when logging on epoch level in distributed setting to accumulate the metric across"
                            " devices.",
                            category=PossibleUserWarning,
                        )
                result_metric.compute()
                result_metric.meta.sync.should = should

            cache = result_metric._computed

        if cache is not None:
            if not isinstance(cache, Tensor):
                raise ValueError(
@@ -536,6 +552,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
                )
            if not result_metric.meta.enable_graph:
                return cache.detach()

        return cache

    def valid_items(self) -> Generator:
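The `_get_cache` hunk above replaces the old unconditional epoch-end sync with a conditional one. A condensed, self-contained sketch of that decision follows; plain stand-ins (`should_sync`, `distributed`, `fault_tolerant`, `compute_fn`) replace the Lightning internals, so this is an illustration of the control flow, not the library code:

```python
import warnings


class PossibleUserWarning(UserWarning):
    """Stand-in for pytorch_lightning.utilities.warnings.PossibleUserWarning."""


def epoch_end_value(name, should_sync, distributed, fault_tolerant, compute_fn):
    # Only force the cross-device sync when fault-tolerant training needs the metric
    # states to be consistent across ranks after a restart; otherwise keep the user's
    # sync_dist=False and warn that the epoch-level value is not accumulated across devices.
    if not should_sync and distributed:
        if fault_tolerant:
            should_sync = True
        else:
            warnings.warn(
                f"It is recommended to use `self.log({name!r}, ..., sync_dist=True)` when logging"
                " on epoch level in distributed setting to accumulate the metric across devices.",
                category=PossibleUserWarning,
            )
    return compute_fn(sync=should_sync)


# Example: with sync_dist=False in a distributed, non-fault-tolerant run, the warning fires
# and the local (un-synced) value is returned.
value = epoch_end_value("bar", False, True, False, lambda sync: 7.0)
```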
25 changes: 24 additions & 1 deletion tests/tests_pytorch/core/test_metric_result_integration.py
@@ -34,7 +34,9 @@
    _ResultMetric,
    _Sync,
)
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import no_warning_call


class DummyMetric(Metric):
@@ -456,6 +458,8 @@ def on_train_epoch_end(self) -> None:
"limit_val_batches": 0,
"accelerator": accelerator,
"devices": devices,
"enable_progress_bar": False,
"enable_model_summary": False,
}
trainer_kwargs.update(kwargs)
trainer = Trainer(**trainer_kwargs)
@@ -471,7 +475,7 @@ def on_train_epoch_end(self) -> None:
    )
    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")

-    trainer = Trainer(**trainer_kwargs, enable_progress_bar=False, enable_model_summary=False)
+    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model, ckpt_path=ckpt_path)
    assert model.has_validated_sum

@@ -659,3 +663,22 @@ def on_train_start(self):
    )
    with pytest.raises(ValueError, match=r"compute\(\)` return of.*foo' must be a tensor"):
        trainer.fit(model)


@pytest.mark.parametrize("distributed_env", [True, False])
def test_logger_sync_dist(distributed_env):
    # self.log('bar', 7, ..., sync_dist=False)
    meta = _Metadata("foo", "bar")
    meta.sync = _Sync(_should=False)
    result_metric = _ResultMetric(metadata=meta, is_tensor=True)
    result_metric.update(torch.tensor(7.0), 10)

    warning_ctx = pytest.warns if distributed_env else no_warning_call

    with mock.patch(
        "pytorch_lightning.trainer.connectors.logger_connector.result.distributed_available",
        return_value=distributed_env,
    ):
        with warning_ctx(PossibleUserWarning, match=r"recommended to use `self.log\('bar', ..., sync_dist=True\)`"):
            value = _ResultCollection._get_cache(result_metric, on_step=False)
            assert value == 7.0
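The new test leans on `no_warning_call` from `tests_pytorch.helpers.utils` as the counterpart of `pytest.warns`. An illustrative re-implementation of such a helper, assuming it simply asserts that no matching warning is emitted inside the block (a sketch, not the repository's actual helper):

```python
import re
import warnings
from contextlib import contextmanager


@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    # Record everything emitted inside the block, then fail if any warning of the given
    # category (optionally matching the regex) was raised.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        if issubclass(w.category, expected_warning) and (match is None or re.search(match, str(w.message))):
            raise AssertionError(f"Unexpected warning emitted: {w.message}")
```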