
Commit facfda8

rohitgr7 and awaelchli authored
Remove no return warning from val/test step (#6139)
* remove warning
* auto_opt
* chlog
* auto_opt
* no_warning_call
* rm old code
* add warning for predict
* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 217470b commit facfda8

14 files changed, +72 -102 lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
 
 
+- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+
 ### Changed
 
 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
@@ -49,6 +52,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
 
 
+- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+
 - Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))
 
 

pytorch_lightning/overrides/base.py

Lines changed: 0 additions & 19 deletions
@@ -11,17 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
-
 import torch
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.warnings import WarningCache
-
-warning_cache = WarningCache()
 
 
 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
@@ -53,20 +48,12 @@ def forward(self, *inputs, **kwargs):
             # ddp_plugin ``post_training_step`` hook
             if not self.module.automatic_optimization:
                 trainer.model.require_backward_grad_sync = False
-            warn_if_output_is_none(output, "training_step")
-
         elif trainer and trainer.testing:
             output = self.module.test_step(*inputs, **kwargs)
-            warn_if_output_is_none(output, "test_step")
-
         elif trainer and (trainer.sanity_checking or trainer.validating):
             output = self.module.validation_step(*inputs, **kwargs)
-            warn_if_output_is_none(output, "validation_step")
-
         elif trainer and trainer.predicting:
             output = self.module.predict(*inputs, **kwargs)
-            warn_if_output_is_none(output, "predict")
-
         else:
             output = self.module(*inputs, **kwargs)
 
@@ -76,12 +63,6 @@ def on_post_move_to_device(self):
         pass
 
 
-def warn_if_output_is_none(output: Any, method_name: str) -> None:
-    """ Warns user about which method returned None. """
-    if output is None:
-        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
-
-
 def unwrap_lightning_module(wrapped_model) -> LightningModule:
     model = wrapped_model
     if isinstance(model, (DistributedDataParallel, DataParallel)):
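
For context, the removed guard only warned when a step hook returned None. After this change, validation/test steps that return nothing pass through the DDP/DP wrappers silently. A minimal sketch of such a module (hypothetical, not part of this diff):

import torch
import pytorch_lightning as pl


class NoReturnValModel(pl.LightningModule):
    """Hypothetical module whose val/test steps log metrics but return nothing."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # training_step still has to return a loss under automatic optimization
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.log("val_sum", self.layer(batch).sum())
        # implicitly returns None -- after this commit the wrappers no longer warn about it

    def test_step(self, batch, batch_idx):
        self.log("test_sum", self.layer(batch).sum())
        # implicitly returns None -- also warning-free now

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)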

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 1 addition & 2 deletions
@@ -214,8 +214,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             # save the last weights
             last_path = None
             if (
-                self.lightning_module.trainer.state == TrainerState.FITTING
-                and best_model_path is not None
+                self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
                 and len(best_model_path) > 0
             ):
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 2 deletions
@@ -139,8 +139,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             # save the last weights
             last_path = None
             if (
-                self.lightning_module.trainer.state == TrainerState.FITTING
-                and best_model_path is not None
+                self.lightning_module.trainer.state == TrainerState.FITTING and best_model_path is not None
                 and len(best_model_path) > 0
             ):
                 last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 1 addition & 3 deletions
@@ -297,9 +297,7 @@ def get_evaluate_epoch_results(self):
 
         # log results of evaluation
         if (
-            self.trainer.state != TrainerState.FITTING
-            and self.trainer.evaluating
-            and self.trainer.is_global_zero
+            self.trainer.state != TrainerState.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
             and self.trainer.verbose_evaluate
         ):
             print('-' * 80)

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 1 addition & 2 deletions
@@ -60,8 +60,7 @@ def get_evaluation_dataloaders(self):
                 self.trainer.reset_val_dataloader(model)
             if self.trainer.sanity_checking:
                 self.trainer.num_sanity_val_batches = [
-                    min(self.trainer.num_sanity_val_steps, val_batches)
-                    for val_batches in self.trainer.num_val_batches
+                    min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
                 ]
                 max_batches = self.trainer.num_sanity_val_batches
             else:
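
The reflowed comprehension clamps the number of sanity-check batches per validation dataloader to num_sanity_val_steps. A tiny illustration with hypothetical values:

# hypothetical values, for illustration only
num_sanity_val_steps = 2
num_val_batches = [5, 1, 8]  # batches available in each val dataloader

num_sanity_val_batches = [min(num_sanity_val_steps, val_batches) for val_batches in num_val_batches]
assert num_sanity_val_batches == [2, 1, 2]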

pytorch_lightning/trainer/predict_loop.py

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,7 @@
 import torch
 
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.warnings import WarningCache
 
 
 class PredictLoop(object):
@@ -22,6 +23,7 @@ def __init__(self, trainer):
         self.trainer = trainer
         self.max_batches = None
         self.num_dataloaders = None
+        self.warning_cache = WarningCache()
 
     def on_trainer_init(self):
         self.trainer.num_predict_batches = []
@@ -74,6 +76,10 @@ def predict(self, batch, batch_idx, dataloader_idx):
 
         model_ref._current_fx_name = "predict"
         predictions = self.trainer.accelerator.predict(args)
+
+        if predictions is None:
+            self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")
+
         self._predictions[dataloader_idx].append(predictions)
         self.trainer._progress_bar_callback.on_predict_batch_end(
             self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
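
PredictLoop now keeps its own WarningCache so the "returned None" message is emitted at most once rather than on every batch. A rough, self-contained sketch of that deduplication idea (SimpleWarningCache is a stand-in here, not the library's class):

import warnings


class SimpleWarningCache:
    """Emit each distinct warning message only once."""

    def __init__(self):
        self._seen = set()

    def warn(self, message: str, stacklevel: int = 2) -> None:
        if message not in self._seen:
            self._seen.add(message)
            warnings.warn(message, UserWarning, stacklevel=stacklevel)


cache = SimpleWarningCache()
for batch_idx in range(10):
    predictions = None  # stand-in for a predict() hook that forgot to return anything
    if predictions is None:
        # warned on the first batch, then silently skipped for the remaining ones
        cache.warn("predict returned None if it was on purpose, ignore this warning...")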

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 2 deletions
@@ -878,8 +878,7 @@ def test(
         # Attach datamodule to get setup/prepare_data added to model before the call to it below
         self.data_connector.attach_datamodule(model, datamodule)
         results = (
-            self.__evaluate_given_model(model, dataloaders=test_dataloaders)
-            if model_provided else
+            self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else
             self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders)
         )
 
tests/helpers/utils.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def no_warning_call(warning_type, match: Optional[str] = None):
 
     try:
         w = record.pop(warning_type)
-        if not ((match and match in w.text) or w):
+        if not (match and match in str(w.message)):
             return
     except AssertionError:
         # no warning raised
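
The one-line fix matters because pytest's recorded warnings expose the original warning instance via .message and have no .text attribute, so text matching has to go through str(w.message). A small sketch, assuming a pytest version where pytest.warns(None) is still allowed (as it was at the time of this commit):

import warnings

import pytest


def test_recorded_warning_text():
    with pytest.warns(None) as record:  # capture every warning raised inside the block
        warnings.warn("Your predict returned None", UserWarning)

    w = record.pop(UserWarning)  # raises AssertionError if no UserWarning was captured
    assert "returned None" in str(w.message)  # .message is the Warning instance; str() gives its text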

tests/overrides/test_data_parallel.py

Lines changed: 9 additions & 57 deletions
@@ -5,7 +5,6 @@
 from torch.nn import DataParallel
 
 from pytorch_lightning.overrides import LightningDistributedModule
-from pytorch_lightning.overrides.base import warning_cache
 from pytorch_lightning.overrides.data_parallel import (
     LightningParallelModule,
     python_scalar_to_tensor,
@@ -20,12 +19,14 @@
     LightningParallelModule,
     LightningDistributedModule,
 ])
-@pytest.mark.parametrize("stage", [
-    ("training", "training_step"),
-    ("testing", "test_step"),
-    ("validating", "validation_step"),
-    ("predicting", "predict"),
-])
+@pytest.mark.parametrize(
+    "stage", [
+        ("training", "training_step"),
+        ("testing", "test_step"),
+        ("validating", "validation_step"),
+        ("predicting", "predict"),
+    ]
+)
 def test_lightning_wrapper_module_methods(wrapper_class, stage):
     """ Test that the LightningWrapper redirects .forward() to the LightningModule methods. """
     pl_module = MagicMock()
@@ -36,63 +37,14 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage):
 
     prop, step = stage
     pl_module.trainer.sanity_checking = False
+
     for p in ("training", "testing", "validating", "predicting"):
         setattr(pl_module.trainer, p, p == prop)
 
     wrapped_module(batch, batch_idx)
-
     getattr(pl_module, step).assert_called_with(batch, batch_idx)
 
 
-@pytest.mark.parametrize("wrapper_class", [
-    LightningParallelModule,
-    LightningDistributedModule,
-])
-@pytest.mark.parametrize("stage", [
-    ("training", "training_step"),
-    ("testing", "test_step"),
-    ("validating", "validation_step"),
-])
-def test_lightning_wrapper_module_warn_none_output(wrapper_class, stage):
-    """ Test that the LightningWrapper module warns about forgotten return statement. """
-    warning_cache.clear()
-    pl_module = MagicMock()
-
-    prop, step = stage
-    pl_module.trainer.sanity_checking = False
-    for p in ("training", "testing", "validating", "predicting"):
-        setattr(pl_module.trainer, p, p == prop)
-
-    wrapped_module = wrapper_class(pl_module)
-
-    getattr(pl_module, step).return_value = None
-
-    with pytest.warns(UserWarning, match=f"Your {step} returned None"):
-        wrapped_module()
-
-
-@pytest.mark.parametrize("wrapper_class", [
-    LightningParallelModule,
-    LightningDistributedModule,
-])
-def test_lightning_wrapper_module_no_warn(wrapper_class):
-    warning_cache.clear()
-    pl_module = MagicMock()
-
-    pl_module.trainer.sanity_checking = False
-    pl_module.trainer.training = False
-    pl_module.trainer.testing = False
-    pl_module.trainer.validating = False
-    pl_module.trainer.predicting = False
-
-    wrapped_module = wrapper_class(pl_module)
-
-    with pytest.warns(None) as record:
-        wrapped_module()
-    pl_module.assert_called()
-    assert not record
-
-
 @pytest.mark.parametrize(
     "inp,expected", [
         [torch.tensor(1.0), torch.tensor([1.0])],
