From 3e983f687469cad1d1e09eca87114b14a15f3996 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 10 Nov 2021 16:47:54 +0100
Subject: [PATCH 1/4] Squeeze the early stopping monitor

---
 pytorch_lightning/callbacks/early_stopping.py |  2 +-
 tests/callbacks/test_early_stopping.py        | 13 +++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 03b268f714a74..b3c67ffbf849f 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -196,7 +196,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         ):  # short circuit if metric not present
             return
 
-        current = logs.get(self.monitor)
+        current = logs.get(self.monitor).squeeze()
         should_stop, reason = self._evaluate_stopping_criteria(current)
 
         # stop every ddp process if any world process decides to stop
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 9b20b96778e65..da200cc336504 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -469,3 +469,16 @@ def validation_step(self, batch, batch_idx):
         assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
     else:
         assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
+
+
+def test_early_stopping_squeezes():
+    early_stopping = EarlyStopping(monitor="foo")
+    trainer = Trainer()
+    trainer.callback_metrics["foo"] = torch.tensor([[[0]]])
+
+    with mock.patch(
+        "pytorch_lightning.callbacks.EarlyStopping._evaluate_stopping_criteria", return_value=(False, "")
+    ) as es_mock:
+        early_stopping._run_early_stopping_check(trainer)
+
+    es_mock.assert_called_once_with(torch.tensor(0))

From 03d762bb824c7fbec31a6bf2b0b45d156a53613f Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 10 Nov 2021 16:50:19 +0100
Subject: [PATCH 2/4] Update CHANGELOG

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0082201aa1cf9..5abb3f1729b67 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -121,7 +121,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
 
 
--
+- Squeeze the early stopping monitor ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
 
 
 -

From e5ae5732c8dddf6e854799da5c52e74bfa967f54 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Wed, 10 Nov 2021 18:55:32 +0100
Subject: [PATCH 3/4] Update

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index b3c67ffbf849f..e292cd961711a 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -196,7 +196,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
         ):  # short circuit if metric not present
             return
 
-        current = logs.get(self.monitor).squeeze()
+        current = logs[self.monitor].squeeze()
         should_stop, reason = self._evaluate_stopping_criteria(current)
 
         # stop every ddp process if any world process decides to stop

From 17a3904653326c9016e920abfae324ade0404446 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 12 Nov 2021 17:02:17 +0100
Subject: [PATCH 4/4] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5abb3f1729b67..9ae371651fcef 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -121,7 +121,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))
 
 
-- Squeeze the early stopping monitor ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
+- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))
 
 
 -
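
Not part of the patches above: a minimal standalone sketch, under the assumption that a monitored value can be logged with extra singleton dimensions, of what the `.squeeze()` call normalizes. It mirrors the tensor shape used in the new test; the value below is illustrative only.

import torch

# A monitored metric may arrive as a tensor with extra singleton dimensions,
# e.g. shape (1, 1, 1). Squeezing it yields a 0-dim tensor, which is what the
# stopping-criteria comparison in the patch then receives.
current = torch.tensor([[[0.25]]])  # hypothetical monitored value, shape (1, 1, 1)
squeezed = current.squeeze()        # 0-dim tensor: tensor(0.2500)

assert squeezed.ndim == 0
assert torch.equal(squeezed, torch.tensor(0.25))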