From d862250246a0f349d4a422355cd14ccaa3e863cd Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 11 Feb 2022 01:24:18 +0100
Subject: [PATCH] Refactor early stopping test

---
 tests/callbacks/test_early_stopping.py | 129 +++++++++----------------
 1 file changed, 45 insertions(+), 84 deletions(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 60f1317019292..bb5e9ba5a7349 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import math
 import pickle
 from typing import List, Optional
 from unittest import mock
@@ -264,100 +265,60 @@ def validation_epoch_end(self, outputs):
     assert early_stopping.stopped_epoch == expected_stop_epoch
 
 
-@pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
-def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
-    """Excepted Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when
-    `early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` ==
-    `min_steps`, and stop.
-
-    IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
-    when `early_stopping` is being triggered,
-    THEN the trainer should continue until reaching
-    `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
-    This test validate this expected behaviour
-
-    IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
-    when `early_stopping` is being triggered,
-    THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
-
-    Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
-
-    This test validate those expected behaviours
-    """
-
-    _logger.disabled = True
-
-    original_loss_value = 10
-    limit_train_batches = 3
-    patience = 3
-
-    class Model(BoringModel):
-        def __init__(self, step_freeze):
-            super().__init__()
-
-            self._step_freeze = step_freeze
-
-            self._loss_value = 10.0
-            self._eps = 1e-1
-            self._count_decrease = 0
-            self._values = []
+@pytest.mark.parametrize("limit_train_batches", (3, 5))
+@pytest.mark.parametrize(
+    ["min_epochs", "min_steps"],
+    [
+        # IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being
+        # triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop
+        (0, 10),
+        # IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is
+        # being triggered, THEN the trainer should continue until reaching
+        # `trainer.global_step` == `min_epochs * len(train_dataloader)`
+        (2, 0),
+        # IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when
+        # `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and
+        # `min_steps` would be reached
+        (1, 10),
+        (3, 10),
+    ],
+)
+def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps):
+    if min_steps:
+        assert limit_train_batches < min_steps
 
+    class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"loss": loss}
-
-        def validation_step(self, batch, batch_idx):
-            return {"test_val_loss": self._loss_value}
+            self.log("foo", batch_idx)
+            return super().training_step(batch, batch_idx)
 
-        def validation_epoch_end(self, outputs):
-            _mean = np.mean([x["test_val_loss"] for x in outputs])
-            if self.trainer.global_step <= self._step_freeze:
-                self._count_decrease += 1
-                self._loss_value -= self._eps
-            self._values.append(_mean)
-            self.log("test_val_loss", _mean)
-
-    model = Model(step_freeze)
-    model.training_step_end = None
-    model.test_dataloader = None
-    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
+    es_callback = EarlyStopping("foo")
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[early_stop_callback],
+        callbacks=es_callback,
+        limit_val_batches=0,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
-        min_steps=min_steps,
         min_epochs=min_epochs,
+        min_steps=min_steps,
+        logger=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
     )
-    trainer.fit(model)
-
-    # Make sure loss was properly decreased
-    assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6
-
-    pos_diff = (np.diff(model._values) == 0).nonzero()[0][0]
-
-    # Compute when the latest validation epoch end happened
-    latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches
-    if pos_diff % limit_train_batches == 0:
-        latest_validation_epoch_end += limit_train_batches
-
-    # Compute early stopping latest step
-    by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience
-
-    # Compute min_epochs latest step
-    by_min_epochs = min_epochs * limit_train_batches
+    model = TestModel()
 
-    # Make sure the trainer stops for the max of all minimum requirements
-    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
-        trainer.global_step,
-        max(min_steps, by_early_stopping, by_min_epochs),
-        step_freeze,
-        min_steps,
-        min_epochs,
-    )
+    expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
+    # trigger early stopping directly after the first epoch
+    side_effect = [(True, "")] * expected_epochs
+    with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect):
+        trainer.fit(model)
 
-    _logger.disabled = False
+    assert trainer.should_stop
+    # epochs continue until min steps are reached
+    assert trainer.current_epoch == expected_epochs
+    # steps continue until min steps are reached AND the epoch is exhausted
+    # stopping mid-epoch is not supported
+    assert trainer.global_step == limit_train_batches * expected_epochs
 
 
 def test_early_stopping_mode_options():
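
Note: the refactored test makes early stopping deterministic by stubbing
`EarlyStopping._evaluate_stopping_criteria` (the hook that returns a
`(should_stop, reason)` tuple, invoked once per epoch in this test) with
`mock.patch.object` and a `side_effect` list, so stopping is requested after
every epoch while `min_epochs`/`min_steps` force training to continue. Below
is a minimal standalone sketch of that mocking pattern, using only the
standard library; `StubEarlyStopping` and its method body are hypothetical
stand-ins for the real callback, not Lightning code:

    import math
    from unittest import mock

    class StubEarlyStopping:
        def _evaluate_stopping_criteria(self):
            # stand-in for the real criteria check (hypothetical)
            return False, ""

    # mirrors the test's arithmetic for the (min_epochs=0, min_steps=10) case:
    # ceil(10 / 3) = 4 epochs, hence 4 * 3 = 12 global steps before stopping
    limit_train_batches, min_steps, min_epochs = 3, 10, 0
    expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)

    cb = StubEarlyStopping()
    # `side_effect` hands out one canned return value per call,
    # so every "epoch" asks to stop
    with mock.patch.object(cb, "_evaluate_stopping_criteria", side_effect=[(True, "")] * expected_epochs):
        for _ in range(expected_epochs):
            should_stop, reason = cb._evaluate_stopping_criteria()
            assert should_stop and reason == ""

Because the mock answers `(True, "")` every time, the final `current_epoch`
and `global_step` assertions isolate the `min_epochs`/`min_steps` floor from
the monitored-metric behaviour the old test depended on.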