|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import logging |
14 | 15 | from unittest.mock import patch |
15 | 16 |
|
16 | 17 | import pytest |
@@ -184,3 +185,36 @@ def test_no_val_on_train_epoch_loop_restart(tmpdir): |
184 | 185 | ) as advance_mocked: |
185 | 186 | trainer.fit(model, ckpt_path=ckpt_path) |
186 | 187 | assert advance_mocked.call_count == 1 |
| 188 | + |
| 189 | + |
@pytest.mark.parametrize(
    ["min_epochs", "min_steps", "current_epoch", "global_step", "early_stop", "epoch_loop_done", "raise_info_msg"],
    [
        (None, None, 1, 4, True, True, False),
        (None, None, 1, 10, True, True, False),
        (4, None, 1, 4, False, False, True),
        (4, 2, 1, 4, False, False, True),
        (4, None, 1, 10, False, True, False),
        (4, 3, 1, 3, False, False, True),
        (4, 10, 1, 10, False, True, False),
        (None, 4, 1, 4, True, True, False),
    ],
)
def test_should_stop_early_stopping_conditions_not_met(
    caplog, min_epochs, min_steps, current_epoch, global_step, early_stop, epoch_loop_done, raise_info_msg
):
    """Ensure an info message is logged when the user sets ``should_stop`` but the ``min_epochs`` /
    ``min_steps`` conditions are not yet satisfied, and that the loop keeps running in that case."""
    trainer = Trainer(min_epochs=min_epochs, min_steps=min_steps, limit_val_batches=0)
    trainer.num_training_batches = 10
    trainer.should_stop = True

    # Fake the loop progress counters to simulate training state at the point of the stop request.
    epoch_loop = trainer.fit_loop.epoch_loop
    epoch_loop.batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = global_step
    epoch_loop.batch_progress.current.ready = global_step
    trainer.fit_loop.epoch_progress.current.completed = current_epoch - 1

    # Substring of the info message emitted when stopping is deferred by the min-duration settings.
    expected = f"min_epochs={min_epochs}` or `min_steps={min_steps}` has not been met. Training will continue"
    with caplog.at_level(logging.INFO, logger="pytorch_lightning.loops"):
        assert epoch_loop.done is epoch_loop_done

    assert (expected in caplog.text) is raise_info_msg
    assert trainer.fit_loop._should_stop_early is early_stop
0 commit comments