|
33 | 33 | from tests.helpers.datamodules import ClassifDataModule |
34 | 34 | from tests.helpers.runif import RunIf |
35 | 35 | from tests.helpers.simple_models import ClassificationModel |
36 | | -from tests.helpers.utils import no_warning_call |
| 36 | +from tests.loops.test_loops import CustomException |
37 | 37 |
|
38 | 38 |
|
39 | 39 | class ModelTrainerPropertyParity(Callback): |
@@ -774,44 +774,59 @@ def test_model_pickle(tmpdir): |
774 | 774 | cloudpickle.dumps(model) |
775 | 775 |
|
776 | 776 |
|
777 | | -@pytest.mark.parametrize("stop_batch_idx", [4, 7]) |
778 | | -def test_restarting_mid_epoch_raises_warning(tmpdir, stop_batch_idx): |
779 | | - """Test that a warning is raised if training is restarted from mid-epoch.""" |
| 777 | +class ExceptionModel(BoringModel): |
| 778 | + def __init__(self, stop_batch_idx): |
| 779 | + super().__init__() |
| 780 | + self.stop_batch_idx = stop_batch_idx |
780 | 781 |
|
781 | | - class CustomModel(BoringModel): |
782 | | - def __init__(self, stop_batch_idx): |
783 | | - super().__init__() |
784 | | - self.stop_batch_idx = stop_batch_idx |
| 782 | + def training_step(self, batch, batch_idx): |
| 783 | + if batch_idx == self.stop_batch_idx: |
| 784 | + raise CustomException() |
| 785 | + return super().training_step(batch, batch_idx) |
785 | 786 |
|
786 | | - def training_step(self, batch, batch_idx): |
787 | | - if (batch_idx + 1) == self.stop_batch_idx: |
788 | | - self.trainer.should_stop = True |
789 | 787 |
|
790 | | - return super().training_step(batch, batch_idx) |
| 788 | +class ShouldStopModel(ExceptionModel): |
| 789 | + def training_step(self, batch, batch_idx): |
| 790 | + if batch_idx == self.stop_batch_idx: |
| 791 | + # setting should_stop is treated differently to raising an exception. |
| 792 | + # checking both tests that this warning is raised in the correct loop |
| 793 | + self.trainer.should_stop = True |
| 794 | + return super().training_step(batch, batch_idx) |
791 | 795 |
|
792 | | - limit_train_batches = 7 |
| 796 | + |
| 797 | +@pytest.mark.parametrize("stop_in_the_middle", (True, False)) |
| 798 | +@pytest.mark.parametrize("model_cls", (ExceptionModel, ShouldStopModel)) |
| 799 | +def test_restarting_mid_epoch_raises_warning(tmpdir, stop_in_the_middle, model_cls): |
| 800 | + """Test that a warning is raised if training is restarted from mid-epoch.""" |
| 801 | + limit_train_batches = 8 |
793 | 802 | trainer_kwargs = { |
794 | 803 | "default_root_dir": tmpdir, |
795 | 804 | "limit_train_batches": limit_train_batches, |
| 805 | + "limit_val_batches": 0, |
796 | 806 | "enable_progress_bar": False, |
797 | 807 | "enable_model_summary": False, |
798 | 808 | } |
799 | 809 | trainer = Trainer(max_epochs=1, **trainer_kwargs) |
800 | | - model = CustomModel(stop_batch_idx) |
801 | | - trainer.fit(model) |
| 810 | + model = model_cls(limit_train_batches // 2 if stop_in_the_middle else -1) |
| 811 | + |
| 812 | + if stop_in_the_middle: |
| 813 | + with pytest.raises(CustomException): |
| 814 | + trainer.fit(model) |
| 815 | + else: |
| 816 | + trainer.fit(model) |
802 | 817 |
|
803 | 818 | ckpt_path = str(tmpdir / "resume.ckpt") |
804 | 819 | trainer.save_checkpoint(ckpt_path) |
805 | 820 |
|
806 | | - trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs) |
| 821 | + trainer = Trainer(max_epochs=2, **trainer_kwargs) |
| 822 | + model.stop_batch_idx = -1 |
807 | 823 |
|
808 | | - warning_raised = limit_train_batches != stop_batch_idx |
809 | | - context_manager = pytest.warns if warning_raised else no_warning_call |
| 824 | + context_manager = pytest.warns if stop_in_the_middle else tutils.no_warning_call |
810 | 825 | with context_manager(UserWarning, match="resuming from a checkpoint that ended mid-epoch"): |
811 | 826 | trainer.fit(model, ckpt_path=ckpt_path) |
812 | 827 |
|
813 | | - if warning_raised: |
| 828 | + if stop_in_the_middle: |
814 | 829 | with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): |
815 | | - trainer = Trainer(max_epochs=2, limit_val_batches=0, **trainer_kwargs) |
816 | | - with no_warning_call(UserWarning, match="resuming from a checkpoint that ended mid-epoch"): |
| 830 | + trainer = Trainer(max_epochs=2, **trainer_kwargs) |
| 831 | + with tutils.no_warning_call(UserWarning, match="resuming from a checkpoint that ended mid-epoch"): |
817 | 832 | trainer.fit(model, ckpt_path=ckpt_path) |
0 commit comments