Skip to content

Commit 1d189fb

Browse files
committed
no_warning_call
1 parent 5664204 commit 1d189fb

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

tests/helpers/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
import functools
1515
import os
1616
import traceback
17+
from contextlib import contextmanager
18+
from typing import Optional
19+
20+
import pytest
1721

1822
from pytorch_lightning import seed_everything
1923
from pytorch_lightning.callbacks import ModelCheckpoint
@@ -111,3 +115,18 @@ def inner_f(queue, **kwargs):
111115
assert result == 1, 'expected 1, but returned %s' % result
112116

113117
return wrapper
118+
119+
120+
@contextmanager
121+
def no_warning_call(warning_type, match: Optional[str] = None):
122+
with pytest.warns(None) as record:
123+
yield
124+
125+
try:
126+
w = record.pop(warning_type)
127+
if not ((match and match in str(w.message)) or w):
128+
return
129+
except AssertionError:
130+
# no warning raised
131+
return
132+
raise AssertionError(f"`{warning_type}` was raised: {w}")

tests/overrides/test_data_parallel.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from pytorch_lightning.trainer.states import RunningStage
1515
from tests.helpers import BoringModel
16+
from tests.helpers.utils import no_warning_call
1617

1718

1819
@pytest.mark.parametrize("wrapper_class", [
@@ -65,25 +66,21 @@ def test_lightning_wrapper_module_warn_none_output(wrapper_class):
6566

6667
pl_module.automatic_optimization = False
6768

68-
with pytest.warns(None, match="Your training_step returned None") as record:
69+
with no_warning_call(UserWarning, match="Your training_step returned None"):
6970
pl_module.running_stage = RunningStage.TRAINING
7071
wrapped_module()
71-
assert not record
7272

73-
with pytest.warns(None, match="Your test_step returned None") as record:
73+
with no_warning_call(UserWarning, match="Your test_step returned None"):
7474
pl_module.running_stage = RunningStage.TESTING
7575
wrapped_module()
76-
assert not record
7776

78-
with pytest.warns(None, match="Your validation_step returned None") as record:
77+
with no_warning_call(UserWarning, match="Your validation_step returned None"):
7978
pl_module.running_stage = RunningStage.EVALUATING
8079
wrapped_module()
81-
assert not record
8280

83-
with pytest.warns(None, match="Your predict returned None") as record:
81+
with no_warning_call(UserWarning, match="Your predict returned None"):
8482
pl_module.running_stage = RunningStage.PREDICTING
8583
wrapped_module()
86-
assert not record
8784

8885
with pytest.warns(None) as record:
8986
pl_module.running_stage = None

0 commit comments

Comments
 (0)