Skip to content

Commit 2303f9c

Browse files
authored
Fix(Early Stopping): move best score to device (#7959)
1 parent 92a78d5 commit 2303f9c

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
284284
- Pass the `stage` argument of `Callback.{setup,teardown}` as a keyword ([#7973](https://github.com/PyTorchLightning/pytorch-lightning/pull/7973))
285285

286286

287+
- Fixed move best score to device in EarlyStopping Callback ([#7959](https://github.com/PyTorchLightning/pytorch-lightning/pull/7959))
288+
289+
287290
## [1.3.6] - 2021-06-15
288291

289292
### Fixed

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def _run_early_stopping_check(self, trainer) -> None:
196196
# when in dev debugging
197197
trainer.dev_debugger.track_early_stopping_history(self, current)
198198

199-
should_stop, reason = self._evalute_stopping_criteria(current)
199+
should_stop, reason = self._evalute_stopping_criteria(current, trainer)
200200

201201
# stop every ddp process if any world process decides to stop
202202
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
@@ -206,7 +206,7 @@ def _run_early_stopping_check(self, trainer) -> None:
206206
if reason and self.verbose:
207207
self._log_info(trainer, reason)
208208

209-
def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
209+
def _evalute_stopping_criteria(self, current: torch.Tensor, trainer: 'pl.Trainer') -> Tuple[bool, str]:
210210
should_stop = False
211211
reason = None
212212
if self.check_finite and not torch.isfinite(current):
@@ -229,7 +229,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
229229
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
230230
" Signaling Trainer to stop."
231231
)
232-
elif self.monitor_op(current - self.min_delta, self.best_score):
232+
elif self.monitor_op(current - self.min_delta, self.best_score.to(trainer.lightning_module.device)):
233233
should_stop = False
234234
reason = self._improvement_message(current)
235235
self.best_score = current

0 commit comments

Comments
 (0)