2 changes: 1 addition & 1 deletion src/lightning_lite/strategies/dp.py
@@ -76,7 +76,7 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
return obj

def reduce_boolean_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
return decision

@classmethod
15 changes: 13 additions & 2 deletions src/lightning_lite/strategies/parallel.py
@@ -85,10 +85,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
"""Perform a all_gather on all processes."""
return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def reduce_boolean_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
"""Reduces a boolean decision over distributed processes. By default is analagous to ``all`` from the
standard library, returning ``True`` only if all input decisions evaluate to ``True``. If ``all`` is set to
``False``, it behaves like ``any`` instead.

Args:
decision: A single input decision.
all: Whether to logically emulate ``all`` or ``any``. Defaults to ``True``.

Returns:
bool: The reduced boolean decision.
"""
decision = torch.tensor(int(decision), device=self.root_device)
decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
decision = bool(decision == self.world_size)
decision = bool(decision == self.world_size) if all else bool(decision)
return decision

def teardown(self) -> None:
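For reference, a minimal single-process sketch (plain Python, not part of this diff) of the reduction semantics added above: summing the per-rank decisions and comparing the total against the world size reproduces ``all``, while treating any nonzero total as ``True`` reproduces ``any``. The helper name and the decision lists are hypothetical.

def reduce_decisions(decisions, use_all=True):
    # Stand-in for the cross-rank ReduceOp.SUM followed by the comparison above.
    total = sum(int(d) for d in decisions)
    return total == len(decisions) if use_all else bool(total)

assert reduce_decisions([True, True, True], use_all=True) is True
assert reduce_decisions([True, False, True], use_all=True) is False
assert reduce_decisions([False, False, True], use_all=False) is True
assert reduce_decisions([False, False, False], use_all=False) is False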
2 changes: 1 addition & 1 deletion src/lightning_lite/strategies/strategy.py
@@ -221,7 +221,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
sync_grads: flag that allows users to synchronize gradients for all_gather op
"""

def reduce_boolean_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
"""Reduce a boolean decision across all processes."""
return decision

3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `ModelCheckpoint.save_on_train_epoch_end` attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders ([#15300](https://github.com/Lightning-AI/lightning/pull/15300))

### Fixed

- Fixed `reduce_boolean_decision` to support the `any`-style reduction expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))

### Deprecated

2 changes: 1 addition & 1 deletion src/pytorch_lightning/callbacks/early_stopping.py
@@ -202,7 +202,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
should_stop, reason = self._evaluate_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
should_stop = trainer.strategy.reduce_boolean_decision(should_stop, all=False)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
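To illustrate the behavioural change for early stopping (hypothetical values, not Lightning code): with two ranks where only rank 0 meets the stopping criterion, the previous ``all``-style reduction kept training going, while the ``any``-style reduction now stops every rank.

per_rank_should_stop = [True, False]  # only rank 0 wants to stop
world_size = len(per_rank_should_stop)
summed = sum(per_rank_should_stop)  # stand-in for the cross-rank SUM reduce

# Old behaviour (all=True): stop only if every rank wants to stop -> keep training.
assert (summed == world_size) is False

# New behaviour (all=False): stop as soon as any rank wants to stop.
assert bool(summed) is True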
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/dp.py
@@ -125,7 +125,7 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
return obj

def reduce_boolean_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
return decision

def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
15 changes: 13 additions & 2 deletions src/pytorch_lightning/strategies/parallel.py
@@ -89,10 +89,21 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
"""Perform a all_gather on all processes."""
return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

def reduce_boolean_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
"""Reduces a boolean decision over distributed processes. By default is analagous to ``all`` from the
standard library, returning ``True`` only if all input decisions evaluate to ``True``. If ``all`` is set to
``False``, it behaves like ``any`` instead.

Args:
decision: A single input decision.
all: Whether to logically emulate ``all`` or ``any``. Defaults to ``True``.

Returns:
bool: The reduced boolean decision.
"""
decision = torch.tensor(int(decision), device=self.root_device)
decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
decision = bool(decision == self.world_size)
decision = bool(decision == self.world_size) if all else bool(decision)
return decision

@contextmanager
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
@@ -330,7 +330,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
sync_grads: flag that allows users to synchronize gradients for all_gather op
"""

def reduce_boolean_decision(self, decision: bool) -> bool:
def reduce_boolean_decision(self, decision: bool, all: bool = True) -> bool:
"""Reduce a boolean decision across all processes."""
return decision

37 changes: 27 additions & 10 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -334,14 +334,21 @@ def test_early_stopping_mode_options():


class EarlyStoppingModel(BoringModel):
def __init__(self, expected_end_epoch: int, early_stop_on_train: bool):
def __init__(self, expected_end_epoch: int, early_stop_on_train: bool, dist_diverge_epoch: Optional[int] = None):
super().__init__()
self.expected_end_epoch = expected_end_epoch
self.early_stop_on_train = early_stop_on_train
self.dist_diverge_epoch = dist_diverge_epoch

def _dist_diverge(self):
should_diverge = (
self.dist_diverge_epoch and self.current_epoch >= self.dist_diverge_epoch and self.trainer.global_rank == 0
)
return 10 if should_diverge else None

def _epoch_end(self) -> None:
losses = [8, 4, 2, 3, 4, 5, 8, 10]
loss = losses[self.current_epoch]
loss = self._dist_diverge() or losses[self.current_epoch]
self.log("abc", torch.tensor(loss))
self.log("cba", torch.tensor(0))

@@ -365,20 +372,28 @@ def on_train_end(self) -> None:


@pytest.mark.parametrize(
"callbacks, expected_stop_epoch, check_on_train_epoch_end, strategy, devices",
"callbacks, expected_stop_epoch, check_on_train_epoch_end, strategy, devices, dist_diverge_epoch",
[
([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, None, 1),
([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, None, 1),
pytest.param([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, **_SPAWN_MARK),
pytest.param([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, **_SPAWN_MARK),
([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, None, 1),
([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, None, 1),
([EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, None, 1, None),
([EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, None, 1, None),
pytest.param(
[EarlyStopping("abc", patience=1), EarlyStopping("cba")], 2, False, "ddp_spawn", 2, 2, **_SPAWN_MARK
),
pytest.param(
[EarlyStopping("abc"), EarlyStopping("cba", patience=3)], 3, False, "ddp_spawn", 2, None, **_SPAWN_MARK
),
pytest.param(
[EarlyStopping("cba", patience=3), EarlyStopping("abc")], 3, False, "ddp_spawn", 2, None, **_SPAWN_MARK
),
([EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)], 3, True, None, 1, None),
([EarlyStopping("cba", **_ES_CHECK_P3), EarlyStopping("abc", **_ES_CHECK)], 3, True, None, 1, None),
pytest.param(
[EarlyStopping("abc", **_ES_CHECK), EarlyStopping("cba", **_ES_CHECK_P3)],
3,
True,
"ddp_spawn",
2,
None,
**_SPAWN_MARK,
),
pytest.param(
@@ -387,6 +402,7 @@ def on_train_end(self) -> None:
True,
"ddp_spawn",
2,
None,
**_SPAWN_MARK,
),
],
@@ -398,10 +414,11 @@ def test_multiple_early_stopping_callbacks(
check_on_train_epoch_end: bool,
strategy: Optional[str],
devices: int,
dist_diverge_epoch: Optional[int],
):
"""Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""

model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end)
model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end, dist_diverge_epoch=dist_diverge_epoch)

trainer = Trainer(
default_root_dir=tmpdir,