@@ -48,38 +48,6 @@ def test_lightning_wrapper_module_methods(wrapper_class):
4848 pl_module .predict .assert_called_with (batch )
4949
5050
51- @pytest .mark .parametrize ("wrapper_class" , [
52- LightningParallelModule ,
53- LightningDistributedModule ,
54- ])
55- def test_lightning_wrapper_module_warn_none_output (wrapper_class ):
56- """ Test that the LightningWrapper module warns about forgotten return statement. """
57- warning_cache .clear ()
58- pl_module = MagicMock ()
59- wrapped_module = wrapper_class (pl_module )
60-
61- pl_module .training_step .return_value = None
62- pl_module .validation_step .return_value = None
63- pl_module .test_step .return_value = None
64-
65- with pytest .warns (UserWarning , match = "Your training_step returned None" ):
66- pl_module .running_stage = RunningStage .TRAINING
67- wrapped_module ()
68-
69- with pytest .warns (UserWarning , match = "Your test_step returned None" ):
70- pl_module .running_stage = RunningStage .TESTING
71- wrapped_module ()
72-
73- with pytest .warns (UserWarning , match = "Your validation_step returned None" ):
74- pl_module .running_stage = RunningStage .EVALUATING
75- wrapped_module ()
76-
77- with pytest .warns (None ) as record :
78- pl_module .running_stage = None
79- wrapped_module ()
80- assert not record
81-
82-
8351@pytest .mark .parametrize (
8452 "inp,expected" , [
8553 [torch .tensor (1.0 ), torch .tensor ([1.0 ])],
0 commit comments