
Commit 40f8023

Remove trainer.fit return value [2/n] (#7237)
* `_fit_impl` refactor and types
* Fix return
* Remove return docstring
* Fixes
* Fixes
* Remove `trainer.fit` return value
* Update CHANGELOG
* flake8
* Undo results change
* Fix test
* Revert changes for a separate PR
* flake8
1 parent bdc4272 commit 40f8023

20 files changed (+35, -72 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -259,6 +259,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed legacy code to log or include metrics in the progress bar by returning them in a dict with the `"log"/"progress_bar"` magic keys. Use `self.log` instead ([#6734](https://github.com/PyTorchLightning/pytorch-lightning/pull/6734))
 
 
+- Removed `trainer.fit()` return value of `1`. It has no return now ([#7237](https://github.com/PyTorchLightning/pytorch-lightning/pull/7237))
+
+
 - Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
 
 
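In practice, this entry means callers can no longer truth-test the return of `trainer.fit()`. A minimal sketch of the migration, assuming `model` is any `LightningModule` instance (the name is a placeholder, not part of this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState

trainer = Trainer(max_epochs=1)

# Before: fit() returned `1` on success, so tests asserted on it.
# result = trainer.fit(model)
# assert result

# After: fit() returns None; check the trainer state instead, as the
# updated tests in this commit do.
trainer.fit(model)  # `model` is a placeholder LightningModule
assert trainer.state == TrainerState.FINISHED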

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 10 deletions

@@ -415,7 +415,7 @@ def _launch(
         train_dataloader: Any = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         datamodule: Optional[LightningDataModule] = None,
-    ) -> Union[int, _EVALUATE_OUTPUT, _PREDICT_OUTPUT]:
+    ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         # set local properties on the model
         self.model_connector.copy_trainer_model_properties(model)

@@ -497,9 +497,7 @@ def _launch(
         self.state = TrainerState.FINISHED
         self._running_stage = None

-        # return 1 when finished
-        # used for testing or when we need to know that training succeeded
-        return self.accelerator.results or 1
+        return self.accelerator.results

     def pre_dispatch(self):
         self.accelerator.pre_dispatch(self)

@@ -836,7 +834,7 @@ def fit(
         train_dataloader: Any = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
         datamodule: Optional[LightningDataModule] = None,
-    ) -> Optional[int]:
+    ) -> None:
         r"""
         Runs the full optimization routine.

@@ -857,15 +855,11 @@ def fit(
         self.state = TrainerState.FITTING
         self.training = True

-        results = self._launch(
-            model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule
-        )
+        self._launch(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloaders, datamodule=datamodule)

         assert self.state.stopped
         self.training = False

-        return results
-
     def validate(
         self,
         model: Optional[LightningModule] = None,
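Since `_launch` now returns `self.accelerator.results` unchanged, the evaluate and predict entry points keep their outputs while fitting yields nothing. A rough sketch of the resulting behavior, with `trainer` and `model` again placeholders:

fit_out = trainer.fit(model)       # always None after this change
val_out = trainer.validate(model)  # still an _EVALUATE_OUTPUT (a list of metric dicts)

assert fit_out is None
assert isinstance(val_out, list)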

pytorch_lightning/utilities/cli.py

Lines changed: 1 addition & 1 deletion

@@ -261,7 +261,7 @@ def before_fit(self) -> None:

     def fit(self) -> None:
         """Runs fit of the instantiated trainer class and prepared fit keyword arguments"""
-        self.fit_result = self.trainer.fit(**self.fit_kwargs)
+        self.trainer.fit(**self.fit_kwargs)

     def after_fit(self) -> None:
         """Implement to run some code after fit has finished"""

tests/accelerators/test_common.py

Lines changed: 1 addition & 2 deletions

@@ -44,8 +44,7 @@ def test_evaluate(tmpdir, trainer_kwargs):
         **trainer_kwargs
     )

-    result = trainer.fit(model, datamodule=dm)
-    assert result
+    trainer.fit(model, datamodule=dm)
     assert 'ckpt' in trainer.checkpoint_callback.best_model_path

     old_weights = model.layer_0.weight.clone().detach().cpu()

tests/accelerators/test_tpu_backend.py

Lines changed: 3 additions & 6 deletions

@@ -94,8 +94,7 @@ def test_weight_tying_warning(tmpdir, capsys=None):
     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

     with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'):
-        result = trainer.fit(model)
-        assert result
+        trainer.fit(model)


 # @RunIf(tpu=True)
@@ -106,8 +105,7 @@ def test_weight_tying_warning(tmpdir, capsys=None):
 # Ensure no warning for parameter mismatch is thrown.
 # """

-# # TODO (kaushikb11): Add `paramter_validation` specific to
-# # TPU Accelerators
+# # TODO (kaushikb11): Add `parameter_validation` specific to TPU Accelerators
 # class Model(WeightSharingModule):

 #     def on_post_move_to_device(self):
@@ -117,8 +115,7 @@ def test_weight_tying_warning(tmpdir, capsys=None):
 # trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

 # with pytest.warns(UserWarning) as warnings:
-#     result = trainer.fit(model)
-#     assert result
+#     trainer.fit(model)

 # assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
 # assert len(trainer.test(model)) == 1

tests/callbacks/test_callback_hook_outputs.py

Lines changed: 1 addition & 2 deletions

@@ -65,8 +65,7 @@ def on_train_epoch_end(self, outputs) -> None:

     assert any(isinstance(c, CB) for c in trainer.callbacks)

-    results = trainer.fit(model)
-    assert results
+    trainer.fit(model)


 def test_on_val_epoch_end_outputs(tmpdir):

tests/callbacks/test_prediction_writer.py

Lines changed: 3 additions & 6 deletions

@@ -49,21 +49,18 @@ def write_on_epoch_end(self, *args, **kwargs):

     cb = CustomPredictionWriter("batch_and_epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
+    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
     assert cb.write_on_batch_end_called
     assert cb.write_on_epoch_end_called
-    assert results == 1

     cb = CustomPredictionWriter("batch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
+    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
     assert cb.write_on_batch_end_called
     assert not cb.write_on_epoch_end_called
-    assert results == 1

     cb = CustomPredictionWriter("epoch")
     trainer = Trainer(limit_predict_batches=4, callbacks=cb)
-    results = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
+    trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
     assert not cb.write_on_batch_end_called
     assert cb.write_on_epoch_end_called
-    assert results == 1
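With the `results or 1` sentinel removed from `_launch`, `trainer.predict(..., return_predictions=False)` no longer hands back a truthy marker; as the updated test shows, success is observed through the writer callback. A sketch of that pattern, reusing the test's `CustomPredictionWriter` and a placeholder `model`:

cb = CustomPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)

out = trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert not out                       # previously `1`; now the falsy accelerator results
assert cb.write_on_batch_end_called  # the callback, not the return value, signals work done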

tests/checkpointing/test_legacy_checkpoints.py

Lines changed: 2 additions & 4 deletions

@@ -76,13 +76,11 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):

     model = DummyModel.load_from_checkpoint(path_ckpt)
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=6)
-    result = trainer.fit(model)
-    assert result
+    trainer.fit(model)

     # todo
     # model = DummyModel()
     # trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, resume_from_checkpoint=path_ckpt)
-    # result = trainer.fit(model)
-    # assert result
+    # trainer.fit(model)

     sys.path = orig_sys_paths

tests/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 4 deletions

@@ -127,8 +127,7 @@ def configure_optimizers(self):
         max_epochs=max_epochs,
         progress_bar_refresh_rate=0,
     )
-    results = trainer.fit(model)
-    assert results
+    trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
@@ -232,8 +231,7 @@ def configure_optimizers(self):
         progress_bar_refresh_rate=0,
         num_sanity_val_steps=0,
     )
-    results = trainer.fit(model)
-    assert results
+    trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))

tests/core/test_datamodules.py

Lines changed: 3 additions & 6 deletions

@@ -271,9 +271,8 @@ def test_train_loop_only(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, datamodule=dm)
+    trainer.fit(model, datamodule=dm)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
-    assert result
     assert trainer.callback_metrics['train_loss'] < 1.0


@@ -294,9 +293,8 @@ def test_train_val_loop_only(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, datamodule=dm)
+    trainer.fit(model, datamodule=dm)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
-    assert result
     assert trainer.callback_metrics['train_loss'] < 1.0


@@ -353,10 +351,9 @@ def test_full_loop(tmpdir):
     )

     # fit model
-    result = trainer.fit(model, dm)
+    trainer.fit(model, dm)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
     assert dm.trainer is not None
-    assert result

     # validate
     result = trainer.validate(datamodule=dm)
