diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0082201aa1cf9..9ae371651fcef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -121,7 +121,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
 
--
+- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
 
 -
 
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 03b268f714a74..e292cd961711a 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -196,7 +196,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         ):  # short circuit if metric not present
             return
 
-        current = logs.get(self.monitor)
+        current = logs[self.monitor].squeeze()
         should_stop, reason = self._evaluate_stopping_criteria(current)
 
         # stop every ddp process if any world process decides to stop
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 9b20b96778e65..da200cc336504 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -469,3 +469,16 @@ def validation_step(self, batch, batch_idx):
         assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
     else:
         assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
+
+
+def test_early_stopping_squeezes():
+    early_stopping = EarlyStopping(monitor="foo")
+    trainer = Trainer()
+    trainer.callback_metrics["foo"] = torch.tensor([[[0]]])
+
+    with mock.patch(
+        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", return_value=(False, "")
+    ) as es_mock:
+        early_stopping._run_early_stopping_check(trainer)
+
+    es_mock.assert_called_once_with(torch.tensor(0))
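
Note (not part of the diff): a minimal sketch of the behavior the change and the new test rely on. If a user logs the monitored metric with extra singleton dimensions, calling `.squeeze()` collapses it to a 0-d tensor before it is passed to `_evaluate_stopping_criteria`, which is why the test expects a plain `torch.tensor(0)`. The `logs` dict and the `"foo"` key below are illustrative stand-ins, not library internals.

```python
import torch

# Assumed example: the monitored value arrives with shape (1, 1, 1),
# e.g. because it was logged as a nested tensor rather than a scalar.
logs = {"foo": torch.tensor([[[0.25]]])}

# Squeezing removes all singleton dimensions, leaving a 0-d tensor.
current = logs["foo"].squeeze()

assert current.ndim == 0
assert torch.equal(current, torch.tensor(0.25))
```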