|
13 | 13 | ) |
14 | 14 | from pytorch_lightning.trainer.states import RunningStage |
15 | 15 | from tests.helpers import BoringModel |
| 16 | +from tests.helpers.utils import no_warning_call |
16 | 17 |
|
17 | 18 |
|
18 | 19 | @pytest.mark.parametrize("wrapper_class", [ |
@@ -65,25 +66,21 @@ def test_lightning_wrapper_module_warn_none_output(wrapper_class): |
65 | 66 |
|
66 | 67 | pl_module.automatic_optimization = False |
67 | 68 |
|
68 | | - with pytest.warns(None, match="Your training_step returned None") as record: |
| 69 | + with no_warning_call(UserWarning, match="Your training_step returned None"): |
69 | 70 | pl_module.running_stage = RunningStage.TRAINING |
70 | 71 | wrapped_module() |
71 | | - assert not record |
72 | 72 |
|
73 | | - with pytest.warns(None, match="Your test_step returned None") as record: |
| 73 | + with no_warning_call(UserWarning, match="Your test_step returned None"): |
74 | 74 | pl_module.running_stage = RunningStage.TESTING |
75 | 75 | wrapped_module() |
76 | | - assert not record |
77 | 76 |
|
78 | | - with pytest.warns(None, match="Your validation_step returned None") as record: |
| 77 | + with no_warning_call(UserWarning, match="Your validation_step returned None"): |
79 | 78 | pl_module.running_stage = RunningStage.EVALUATING |
80 | 79 | wrapped_module() |
81 | | - assert not record |
82 | 80 |
|
83 | | - with pytest.warns(None, match="Your predict returned None") as record: |
| 81 | + with no_warning_call(UserWarning, match="Your predict returned None"): |
84 | 82 | pl_module.running_stage = RunningStage.PREDICTING |
85 | 83 | wrapped_module() |
86 | | - assert not record |
87 | 84 |
|
88 | 85 | with pytest.warns(None) as record: |
89 | 86 | pl_module.running_stage = None |
|
0 commit comments