
Commit c46ae62

Merge 45a010f into 4f391bc
2 parents 4f391bc + 45a010f commit c46ae62


42 files changed, with 497 additions and 638 deletions.

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
@@ -15,6 +15,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


+- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
 - Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


@@ -26,9 +32,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))


+- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
+- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
 ### Deprecated


+- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+
 ### Removed

 - Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
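
For context, a minimal sketch of how user code might consult the state API these entries describe, assuming pytorch-lightning at this commit is installed; the `PrintOnFitValidation` callback and its message are illustrative, not part of the library:

```python
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import TrainerState


class PrintOnFitValidation(Callback):
    """Illustrative callback: acts only on real validation runs during fitting."""

    def on_validation_end(self, trainer, pl_module):
        # `sanity_checking` replaces the deprecated `running_sanity_check`;
        # `trainer.state` is one of FITTING/VALIDATING/TESTING/PREDICTING/TUNING.
        if trainer.sanity_checking or trainer.state != TrainerState.FITTING:
            return
        print("validation finished while fitting")
```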

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 4 deletions
@@ -20,6 +20,7 @@
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import TrainingTypePlugin
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum
@@ -80,8 +81,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
     def start_training(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_training(trainer)

-    def start_testing(self, trainer: 'Trainer') -> None:
-        self.training_type_plugin.start_testing(trainer)
+    def start_evaluating(self, trainer: 'Trainer') -> None:
+        self.training_type_plugin.start_evaluating(trainer)

     def start_predicting(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)
@@ -323,7 +324,7 @@ def setup_optimizers(self, trainer: 'Trainer') -> None:
             trainer: the Trainer, these optimizers should be connected to
             model: the model to be optimized by the created optimizers
         """
-        if trainer.testing:
+        if trainer.state not in (TrainerState.FITTING, TrainerState.TUNING):
             return
         optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
             trainer=trainer, model=self.lightning_module
@@ -417,7 +418,7 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I
     @property
     def results(self) -> Any:
         """
-        The results of the last training/testing run will be cached within the training type plugin.
+        The results of the last run will be cached within the training type plugin.
         In distributed training, we make sure to transfer the results to the appropriate master process.
         """
         return self.training_type_plugin.results
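
A self-contained sketch of the guard introduced in `setup_optimizers` above, using a stand-in enum rather than the real `Trainer` (the names `Stage` and `setup_optimizers` below are illustrative only):

```python
from enum import Enum


class Stage(Enum):  # stand-in for pytorch_lightning.trainer.states.TrainerState
    FITTING = "fit"
    TUNING = "tune"
    VALIDATING = "validate"
    TESTING = "test"


def setup_optimizers(state: Stage) -> str:
    # optimizers are only created for runs that actually train the model
    if state not in (Stage.FITTING, Stage.TUNING):
        return "skipped"
    return "optimizers created"


assert setup_optimizers(Stage.TESTING) == "skipped"
assert setup_optimizers(Stage.FITTING) == "optimizers created"
```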

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 4 additions & 3 deletions
@@ -137,12 +137,13 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
         self.patience = callback_state['patience']

     def on_validation_end(self, trainer, pl_module):
-        if trainer.running_sanity_check:
+        from pytorch_lightning.trainer.states import TrainerState
+        if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
             return

-        self._run_early_stopping_check(trainer, pl_module)
+        self._run_early_stopping_check(trainer)

-    def _run_early_stopping_check(self, trainer, pl_module):
+    def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met
         and if so tells the trainer to stop the training.
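
The `running_sanity_check` flag this hook used to read is deprecated in favor of `sanity_checking`; such renames are commonly bridged with a property shim. A hedged sketch of that general pattern — the `LegacyStateMixin` class is a stand-in, not the library's actual implementation:

```python
from warnings import warn


class LegacyStateMixin:
    """Stand-in showing how an old flag can forward to a renamed one."""

    sanity_checking: bool = False

    @property
    def running_sanity_check(self) -> bool:
        warn(
            "`running_sanity_check` is deprecated, use `sanity_checking` instead.",
            DeprecationWarning,
        )
        return self.sanity_checking
```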

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 3 additions & 1 deletion
@@ -213,12 +213,14 @@ def save_checkpoint(self, trainer, pl_module):
         epoch = trainer.current_epoch
         global_step = trainer.global_step

+        from pytorch_lightning.trainer.states import TrainerState
         if (
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
+            or trainer.sanity_checking  # don't save anything during sanity check
             or self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
-            or trainer.running_sanity_check  # don't save anything during sanity check
             or self._last_global_step_saved == global_step  # already saved at the last step
         ):
             return
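
To make the combined skip conditions above easier to follow, here is an illustrative standalone restatement; the `should_skip_checkpoint` helper and its argument names are hypothetical, not part of the library:

```python
def should_skip_checkpoint(
    fast_dev_run: bool,
    is_fitting: bool,
    sanity_checking: bool,
    save_top_k: int,
    period: int,
    epoch: int,
    global_step: int,
    last_saved_step: int,
) -> bool:
    """Mirror of the guard in ``save_checkpoint``: any True short-circuits saving."""
    return (
        fast_dev_run                       # checkpointing disabled in fast_dev_run
        or not is_fitting                  # only save while fitting
        or sanity_checking                 # never save during the sanity check
        or save_top_k == 0                 # user asked for no checkpoints
        or period < 1                      # user asked for no checkpoints
        or bool((epoch + 1) % period)      # not a checkpointing epoch
        or last_saved_step == global_step  # already saved at this step
    )


# epoch 0 with period 2 is skipped; epoch 1 with new global_step is saved
assert should_skip_checkpoint(False, True, False, 3, 2, epoch=0, global_step=10, last_saved_step=5)
assert not should_skip_checkpoint(False, True, False, 3, 2, epoch=1, global_step=10, last_saved_step=5)
```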

pytorch_lightning/callbacks/progress.py

Lines changed: 3 additions & 2 deletions
@@ -380,7 +380,6 @@ def init_test_tqdm(self) -> tqdm:
     def on_sanity_check_start(self, trainer, pl_module):
         super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
-        reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
         self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

     def on_sanity_check_end(self, trainer, pl_module):
@@ -412,7 +411,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data

     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
-        if not trainer.running_sanity_check:
+        if trainer.sanity_checking:
+            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
+        else:
             self._update_bar(self.main_progress_bar)  # fill up remaining
             self.val_progress_bar = self.init_validation_tqdm()
             reset(self.val_progress_bar, self.total_val_batches)

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 8 deletions
@@ -25,7 +25,7 @@
 from argparse import Namespace
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

 import torch
 from torch import ScriptModule, Tensor
@@ -44,8 +44,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

-if TYPE_CHECKING:
-    from pytorch_lightning.trainer.states import RunningStage
 log = logging.getLogger(__name__)


@@ -69,7 +67,6 @@ class LightningModule(
         "on_gpu",
         "current_epoch",
         "global_step",
-        "running_stage",
         "global_rank",
         "local_rank",
         "logger",
@@ -172,10 +169,6 @@ def automatic_optimization(self) -> bool:
         """
         return self._automatic_optimization

-    @property
-    def running_stage(self) -> Optional["RunningStage"]:
-        return self.trainer._running_stage if self.trainer else None
-
     @automatic_optimization.setter
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization
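
With `LightningModule.running_stage` removed above, code that inspected it would consult the attached trainer's stage flags instead. A hedged migration sketch; the `stage_name` helper is illustrative and not part of the library:

```python
def stage_name(pl_module) -> str:
    """Return a rough stage label for a LightningModule attached to a trainer."""
    trainer = pl_module.trainer
    if trainer is None:
        return "detached"
    if trainer.training:
        return "training"
    if trainer.sanity_checking or trainer.validating:
        return "validating"
    if trainer.testing:
        return "testing"
    if trainer.predicting:
        return "predicting"
    return "unknown"
```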

pytorch_lightning/overrides/base.py

Lines changed: 6 additions & 7 deletions
@@ -18,7 +18,6 @@
 from torch.nn.parallel import DistributedDataParallel

 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -43,28 +42,28 @@ def __init__(self, pl_module: LightningModule):
         self.module = pl_module

     def forward(self, *inputs, **kwargs):
-        running_stage = self.module.running_stage
+        trainer = self.module.trainer

-        if running_stage == RunningStage.TRAINING:
+        if trainer and trainer.training:
             output = self.module.training_step(*inputs, **kwargs)

             # In manual_optimization, we need to prevent DDP reducer as
             # it is done manually in ``LightningModule.manual_backward``
             # `require_backward_grad_sync` will be reset in the
             # ddp_plugin ``post_training_step`` hook
             if not self.module.automatic_optimization:
-                self.module.trainer.model.require_backward_grad_sync = False
+                trainer.model.require_backward_grad_sync = False
             warn_if_output_is_none(output, "training_step")

-        elif running_stage == RunningStage.TESTING:
+        elif trainer and trainer.testing:
             output = self.module.test_step(*inputs, **kwargs)
             warn_if_output_is_none(output, "test_step")

-        elif running_stage == RunningStage.EVALUATING:
+        elif trainer and (trainer.sanity_checking or trainer.validating):
             output = self.module.validation_step(*inputs, **kwargs)
             warn_if_output_is_none(output, "validation_step")

-        elif running_stage == RunningStage.PREDICTING:
+        elif trainer and trainer.predicting:
             output = self.module.predict(*inputs, **kwargs)
             warn_if_output_is_none(output, "predict")
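
A self-contained stand-in for the dispatch pattern above: the wrapper routes `forward` to the matching step method based on trainer flags. All classes here are toy stubs for illustration, not the real `LightningModule`, `Trainer`, or wrapper:

```python
class ToyTrainer:
    def __init__(self, **flags):
        self.training = flags.get("training", False)
        self.validating = flags.get("validating", False)
        self.sanity_checking = flags.get("sanity_checking", False)
        self.testing = flags.get("testing", False)
        self.predicting = flags.get("predicting", False)


class ToyModule:
    trainer: ToyTrainer = None

    def training_step(self, batch):   return f"train({batch})"
    def validation_step(self, batch): return f"val({batch})"
    def test_step(self, batch):       return f"test({batch})"
    def predict(self, batch):         return f"predict({batch})"


class ToyWrapper:
    def __init__(self, module: ToyModule):
        self.module = module

    def forward(self, batch):
        # mirror the flag-based branching used in the diff above
        trainer = self.module.trainer
        if trainer and trainer.training:
            return self.module.training_step(batch)
        if trainer and (trainer.sanity_checking or trainer.validating):
            return self.module.validation_step(batch)
        if trainer and trainer.testing:
            return self.module.test_step(batch)
        if trainer and trainer.predicting:
            return self.module.predict(batch)
        return batch  # no trainer attached: pass the input through


module = ToyModule()
module.trainer = ToyTrainer(validating=True)
assert ToyWrapper(module).forward("x") == "val(x)"
```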

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 9 additions & 7 deletions
@@ -27,6 +27,7 @@
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -103,7 +104,7 @@ def start_training(self, trainer):
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []

-    def start_testing(self, trainer):
+    def start_evaluating(self, trainer):
         mp.spawn(self.new_process, **self.mp_spawn_kwargs)

     def start_predicting(self, trainer):
@@ -152,7 +153,7 @@ def new_process(self, process_idx, trainer, mp_queue):

         self.barrier()

-        results = trainer.train_or_test_or_predict()
+        results = trainer.run_stage()

         # persist info in ddp_spawn
         self.transfer_distrib_spawn_state_on_fit_end(results)
@@ -204,7 +205,6 @@ def on_save(self, checkpoint: dict) -> dict:
         return checkpoint

     def transfer_distrib_spawn_state_on_fit_end(self, results):
-        # TODO: is there a better way than accessing callback through model -> trainer -> callback?
         checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

@@ -213,8 +213,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):

         # save the last weights
         last_path = None
-        # TODO: is there a better way than accessing trainer through model -> trainer?
-        if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+        if (
+            self.lightning_module.trainer.state == TrainerState.FITTING
+            and best_model_path is not None
+            and len(best_model_path) > 0
+        ):
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
             atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

@@ -224,14 +227,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         self.mp_queue.put(results)

     def __recover_child_process_weights(self, best_path, last_path):
-        # TODO: is there a better way than accessing callback through model -> trainer -> callback?
         # transfer back the best path to the trainer
         if self.lightning_module.trainer.checkpoint_callback:
             self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
         # todo, pass also best score

         # load last weights
-        if last_path is not None and not self.lightning_module.trainer.testing:
+        if last_path is not None and self.lightning_module.trainer.state == TrainerState.FITTING:
             ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
             self.lightning_module.load_state_dict(ckpt)
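
A small sketch of the fitting-only hand-off above: the spawned process derives a temporary "last weights" path only when the trainer was fitting. The `derive_last_path` helper is illustrative; the substitution mirrors the one in the diff:

```python
import re
from typing import Optional


def derive_last_path(best_model_path: Optional[str], is_fitting: bool) -> Optional[str]:
    """Return the temporary last-weights path, or None when nothing should be saved."""
    if is_fitting and best_model_path:
        # same substitution as in transfer_distrib_spawn_state_on_fit_end
        return re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
    return None


assert derive_last_path("epoch=3.ckpt", is_fitting=True) == "epoch=3.tmp_end.ckpt"
assert derive_last_path("epoch=3.ckpt", is_fitting=False) is None
```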

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 2 additions & 3 deletions
@@ -213,7 +213,7 @@ def init_deepspeed(self):
         precision = self.lightning_module.trainer.accelerator.precision
         model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

-        if self.lightning_module.trainer.training:
+        if self.lightning_module.trainer and self.lightning_module.trainer.training:
             self._initialize_deepspeed_train(model)
         else:
             self._initialize_deepspeed_inference(model)
@@ -249,8 +249,7 @@ def _initialize_deepspeed_train(self, model):
         )

         # set optimizer for save/load, but deepspeed manages the specific optimizer logic
-        trainer = self.lightning_module.trainer
-        trainer.optimizers = [optimizer]
+        self.lightning_module.trainer.optimizers = [optimizer]
         self.model = model

     def _initialize_deepspeed_inference(self, model):

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 2 additions & 2 deletions
@@ -101,9 +101,9 @@ def start_training(self, trainer):
         # Make sure all workers have finished training before returning to the user
         hvd.join()

-    def start_testing(self, trainer):
+    def start_evaluating(self, trainer):
         with ExitStack():
-            self._results = trainer.run_test()
+            self._results = trainer.run_evaluate()

         # Make sure all workers have finished training before returning to the user
         hvd.join()
