12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import logging |
| 15 | +import math |
15 | 16 | import pickle |
16 | 17 | from typing import List, Optional |
17 | 18 | from unittest import mock |
@@ -264,100 +265,60 @@ def validation_epoch_end(self, outputs): |
264 | 265 | assert early_stopping.stopped_epoch == expected_stop_epoch |
265 | 266 |
266 | 267 |
267 | | -@pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)]) |
268 | | -def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int): |
269 | | - """Excepted Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when |
270 | | - `early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` == |
271 | | - `min_steps`, and stop. |
272 | | -
273 | | - IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` |
274 | | - when `early_stopping` is being triggered, |
275 | | - THEN the trainer should continue until reaching |
276 | | - `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop. |
277 | | - This test validate this expected behaviour |
278 | | -
279 | | - IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` |
280 | | - when `early_stopping` is being triggered, |
281 | | - THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached. |
282 | | -
283 | | - Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader) |
284 | | -
285 | | - This test validate those expected behaviours |
286 | | - """ |
287 | | - |
288 | | - _logger.disabled = True |
289 | | - |
290 | | - original_loss_value = 10 |
291 | | - limit_train_batches = 3 |
292 | | - patience = 3 |
293 | | - |
294 | | - class Model(BoringModel): |
295 | | - def __init__(self, step_freeze): |
296 | | - super().__init__() |
297 | | - |
298 | | - self._step_freeze = step_freeze |
299 | | - |
300 | | - self._loss_value = 10.0 |
301 | | - self._eps = 1e-1 |
302 | | - self._count_decrease = 0 |
303 | | - self._values = [] |
| 268 | +@pytest.mark.parametrize("limit_train_batches", (3, 5)) |
| 269 | +@pytest.mark.parametrize( |
| 270 | + ["min_epochs", "min_steps"], |
| 271 | + [ |
| 272 | + # IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being |
| 273 | + # triggered, THEN the trainer should continue until reaching `trainer.global_step == min_steps` and stop |
| 274 | + (0, 10), |
| 275 | + # IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is |
| 276 | + # being triggered, THEN the trainer should continue until reaching |
| 277 | + # `trainer.global_step` == `min_epochs * len(train_dataloader)` |
| 278 | + (2, 0), |
| 279 | + # IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when |
| 280 | + # `early_stopping` is being triggered, THEN the highest between `min_epochs * len(train_dataloader)` and |
| 281 | + # `min_steps` would be reached |
| 282 | + (1, 10), |
| 283 | + (3, 10), |
| 284 | + ], |
| 285 | +) |
| 286 | +def test_min_epochs_min_steps_global_step(tmpdir, limit_train_batches, min_epochs, min_steps): |
| 287 | + if min_steps: |
| 288 | + assert limit_train_batches < min_steps |
304 | 289 |
| 290 | + class TestModel(BoringModel): |
305 | 291 | def training_step(self, batch, batch_idx): |
306 | | - output = self.layer(batch) |
307 | | - loss = self.loss(batch, output) |
308 | | - return {"loss": loss} |
309 | | - |
310 | | - def validation_step(self, batch, batch_idx): |
311 | | - return {"test_val_loss": self._loss_value} |
| 292 | + self.log("foo", batch_idx) |
| 293 | + return super().training_step(batch, batch_idx) |
312 | 294 |
313 | | - def validation_epoch_end(self, outputs): |
314 | | - _mean = np.mean([x["test_val_loss"] for x in outputs]) |
315 | | - if self.trainer.global_step <= self._step_freeze: |
316 | | - self._count_decrease += 1 |
317 | | - self._loss_value -= self._eps |
318 | | - self._values.append(_mean) |
319 | | - self.log("test_val_loss", _mean) |
320 | | - |
321 | | - model = Model(step_freeze) |
322 | | - model.training_step_end = None |
323 | | - model.test_dataloader = None |
324 | | - early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) |
| 295 | + es_callback = EarlyStopping("foo") |
325 | 296 | trainer = Trainer( |
326 | 297 | default_root_dir=tmpdir, |
327 | | - callbacks=[early_stop_callback], |
| 298 | + callbacks=es_callback, |
| 299 | + limit_val_batches=0, |
328 | 300 | limit_train_batches=limit_train_batches, |
329 | | - limit_val_batches=2, |
330 | | - min_steps=min_steps, |
331 | 301 | min_epochs=min_epochs, |
| 302 | + min_steps=min_steps, |
| 303 | + logger=False, |
| 304 | + enable_checkpointing=False, |
| 305 | + enable_progress_bar=False, |
| 306 | + enable_model_summary=False, |
332 | 307 | ) |
333 | | - trainer.fit(model) |
334 | | - |
335 | | - # Make sure loss was properly decreased |
336 | | - assert abs(original_loss_value - (model._count_decrease) * model._eps - model._loss_value) < 1e-6 |
337 | | - |
338 | | - pos_diff = (np.diff(model._values) == 0).nonzero()[0][0] |
339 | | - |
340 | | - # Compute when the latest validation epoch end happened |
341 | | - latest_validation_epoch_end = (pos_diff // limit_train_batches) * limit_train_batches |
342 | | - if pos_diff % limit_train_batches == 0: |
343 | | - latest_validation_epoch_end += limit_train_batches |
344 | | - |
345 | | - # Compute early stopping latest step |
346 | | - by_early_stopping = latest_validation_epoch_end + (1 + limit_train_batches) * patience |
347 | | - |
348 | | - # Compute min_epochs latest step |
349 | | - by_min_epochs = min_epochs * limit_train_batches |
| 308 | + model = TestModel() |
350 | 309 |
351 | | - # Make sure the trainer stops for the max of all minimum requirements |
352 | | - assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), ( |
353 | | - trainer.global_step, |
354 | | - max(min_steps, by_early_stopping, by_min_epochs), |
355 | | - step_freeze, |
356 | | - min_steps, |
357 | | - min_epochs, |
358 | | - ) |
| 310 | + expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs) |
| 311 | + # trigger early stopping directly after the first epoch |
| 312 | + side_effect = [(True, "")] * expected_epochs |
| 313 | + with mock.patch.object(es_callback, "_evaluate_stopping_criteria", side_effect=side_effect): |
| 314 | + trainer.fit(model) |
359 | 315 |
360 | | - _logger.disabled = False |
| 316 | + assert trainer.should_stop |
| 317 | + # epochs continue until min steps are reached |
| 318 | + assert trainer.current_epoch == expected_epochs |
| 319 | + # steps continue until min steps are reached AND the epoch is exhausted |
| 320 | + # stopping mid-epoch is not supported |
| 321 | + assert trainer.global_step == limit_train_batches * expected_epochs |
361 | 322 |
362 | 323 |
363 | 324 | def test_early_stopping_mode_options(): |
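
Reviewer note, not part of the diff: the `expected_epochs` expression in the new test combines both minimums, rounding `min_steps` up to whole epochs because the loop only stops on an epoch boundary. A minimal standalone sketch of the values the parametrization should produce (`expected_progress` is a hypothetical helper, shown here with `limit_train_batches=3`):

```python
import math

def expected_progress(limit_train_batches, min_epochs, min_steps):
    # `min_steps` is rounded up to whole epochs because stopping mid-epoch is
    # not supported; the larger of the two minimum requirements wins
    expected_epochs = max(math.ceil(min_steps / limit_train_batches), min_epochs)
    return expected_epochs, expected_epochs * limit_train_batches

# the four parametrized (min_epochs, min_steps) pairs from the test above
for min_epochs, min_steps in [(0, 10), (2, 0), (1, 10), (3, 10)]:
    print((min_epochs, min_steps), "->", expected_progress(3, min_epochs, min_steps))
# (0, 10) -> (4, 12)
# (2, 0)  -> (2, 6)
# (1, 10) -> (4, 12)
# (3, 10) -> (4, 12)
```

All three `min_steps=10` cases stop at the same 4-epoch / 12-step boundary because `math.ceil(10 / 3) == 4` dominates each of those `min_epochs` values; this is also why the test asserts that `trainer.global_step` overshoots `min_steps` rather than equalling it.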
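The test also no longer waits for a monitored metric to plateau: `mock.patch.object` with a `side_effect` list makes `_evaluate_stopping_criteria` report "stop" at every epoch-end check. A self-contained sketch of that mocking mechanism, using a stand-in class (only the method name and its `(should_stop, reason)` return shape are taken from the diff):

```python
from unittest import mock

class EarlyStoppingLike:
    # stand-in for the real callback; the real method compares the monitored
    # value against the best score seen so far
    def _evaluate_stopping_criteria(self, current):
        return False, ""

cb = EarlyStoppingLike()
# `side_effect` returns one list element per call, so the patched method says
# "stop" exactly three times; an exhausted list would raise StopIteration
with mock.patch.object(cb, "_evaluate_stopping_criteria", side_effect=[(True, "")] * 3):
    for _ in range(3):
        should_stop, reason = cb._evaluate_stopping_criteria(None)
        assert should_stop and reason == ""
```

Because early stopping fires on every check, the run length is governed entirely by `min_epochs` / `min_steps`, which is exactly what the final `trainer.current_epoch` and `trainer.global_step` assertions pin down.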