129 changes: 45 additions & 84 deletions tests/callbacks/test_early_stopping.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import math
import pickle
from typing import List, Optional
from unittest import mock
@@ -264,100 +265,60 @@ def validation_epoch_end(self, outputs):
    assert early_stopping.stopped_epoch == expected_stop_epoch


@pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
"""Excepted Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when
`early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` ==
`min_steps`, and stop.

IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
when `early_stopping` is being triggered,
THEN the trainer should continue until reaching
`trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
This test validate this expected behaviour

IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
when `early_stopping` is being triggered,
THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.

Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)

This test validate those expected behaviours
"""

_logger.disabled = True

original_loss_value = 10
limit_train_batches = 3
patience = 3

class Model(BoringModel):
def __init__(self, step_freeze):
super().__init__()

self._step_freeze = step_freeze

self._loss_value = 10.0
self._eps = 1e-1
self._count_decrease = 0
self._values = []
@pytest.mark.parametrize("limit_train_batches", (3, 5))
@pytest.mark.parametrize(
["min_epochs", "min_steps"],
[
# IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
# triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
(0, 10),
# IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
# being triggered, THEN the trainer should continue until reaching
# `trainer.global_step` == `min_epochs * len(train_dataloader)`
(2, 0),
# IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
# `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
# `min_steps` would be reached
(1, 10),
(3, 10),
],
)
def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
if min_steps:
assert limit_train_batches < min_steps

class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"loss": loss}
-
-        def validation_step(self, batch, batch_idx):
-            return {"test_val_loss": self._loss_value}
+            self.log("foo", batch_idx)
+            return super().training_step(batch, batch_idx)

-        def validation_epoch_end(self, outputs):
-            _mean = np.mean([x["test_val_loss"] for x in outputs])
-            if self.trainer.global_step <= self._step_freeze:
-                self._count_decrease += 1
-                self._loss_value -= self._eps
-            self._values.append(_mean)
-            self.log("test_val_loss", _mean)
-
-    model = Model(step_freeze)
-    model.training_step_end = None
-    model.test_dataloader = None
-    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
+    es_callback = EarlyStopping("foo")
    trainer = Trainer(
        default_root_dir=tmpdir,
-        callbacks=[early_stop_callback],
+        callbacks=es_callback,
+        limit_val_batches=0,
        limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
-        min_steps=min_steps,
        min_epochs=min_epochs,
+        min_steps=min_steps,
+        logger=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
    )
-    trainer.fit(model)
-
-    # Make sure loss was properly decreased
-    assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6
-
-    pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
-
-    # Compute when the latest validation epoch end happened
-    latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
-    if pos_diff % limit_train_batches == 0:
-        latest_validation_epoch_end += limit_train_batches
-
-    # Compute early stopping latest step
-    by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
-
-    # Compute min_epochs latest step
-    by_min_epochs = min_epochs * limit_train_batches
+    model = TestModel()

-    # Make sure the trainer stops for the max of all minimum requirements
-    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
-        trainer.global_step,
-        max(min_steps, by_early_stopping, by_min_epochs),
-        step_freeze,
-        min_steps,
-        min_epochs,
-    )
+    expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
+    # trigger early stopping directly after the first epoch
+    side_effect = [(True, "")] * expected_epochs
+    with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
+        trainer.fit(model)

-    _logger.disabled = False
+    assert trainer.should_stop
+    # epochs continue until min steps are reached
+    assert trainer.current_epoch == expected_epochs
+    # steps continue until min steps are reached AND the epoch is exhausted
+    # stopping mid-epoch is not supported
+    assert trainer.global_step == limit_train_batches * expected_epochs


def test_early_stopping_mode_options():
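
Reviewer note (not part of the diff): the new assertions encode that a stop requested by early stopping is only honored on an epoch boundary, so the run length is `expected_epochs = max(ceil(min_steps / limit_train_batches), min_epochs)` and `global_step = limit_train_batches * expected_epochs`. The sketch below is only a sanity check that re-derives those numbers for every parametrized case; the formula and the parameter values are taken directly from the test, and it assumes exactly `limit_train_batches` optimizer steps per epoch:

    import math

    # Re-derive the values the new test asserts, for each parametrized case.
    # Assumes the (mocked) early stopping requests a stop after every epoch and
    # that the trainer only honors it once min_epochs/min_steps are satisfied.
    for limit_train_batches in (3, 5):
        for min_epochs, min_steps in [(0, 10), (2, 0), (1, 10), (3, 10)]:
            expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
            global_step = limit_train_batches * expected_epochs
            print(
                f"batches/epoch={limit_train_batches} min_epochs={min_epochs} "
                f"min_steps={min_steps} -> epochs={expected_epochs} steps={global_step}"
            )

For example, `limit_train_batches=3` with `min_steps=10` gives `ceil(10 / 3) = 4` epochs and 12 global steps rather than exactly 10: mid-epoch stopping is not supported, so the trainer overshoots `min_steps` to the next epoch boundary, which is why the assertion compares against `limit_train_batches * expected_epochs` instead of `min_steps` itself.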
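Also worth noting: the rewritten test no longer drives a real monitored metric. It patches `EarlyStopping._evaluate_stopping_criteria`, which the test treats as returning a `(should_stop, reason)` tuple, so that every epoch-end check requests a stop. Below is a minimal, self-contained sketch of the list `side_effect` mechanism it relies on, using a hypothetical `DecisionStub` class rather than Lightning internals:

    from unittest import mock


    class DecisionStub:
        """Hypothetical stand-in for the object whose decision method is patched."""

        def evaluate(self):
            return (False, "")


    stub = DecisionStub()
    # With a list side_effect, each call to the patched method consumes the next
    # element; here, two calls mimic two epoch-end checks that both say "stop".
    with mock.patch.object(stub, "evaluate", side_effect=[(True, "forced stop")] * 2):
        assert stub.evaluate() == (True, "forced stop")
        assert stub.evaluate() == (True, "forced stop")

Because the list in the test holds exactly `expected_epochs` entries, any extra call to the patched method would raise `StopIteration`, which incidentally also guards the number of epoch-end checks.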