From 9babf3f7ae297a76e7eb48ef6056674d27591bf6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 19 Jan 2021 18:14:04 +0000 Subject: [PATCH 01/43] start adding predict --- .../accelerators/cpu_accelerator.py | 3 + .../accelerators/ddp_accelerator.py | 3 + pytorch_lightning/core/lightning.py | 13 ++-- pytorch_lightning/overrides/data_parallel.py | 19 +++++- pytorch_lightning/trainer/evaluation_loop.py | 10 ++- pytorch_lightning/trainer/states.py | 2 + pytorch_lightning/trainer/trainer.py | 63 ++++++++++++++++++- tests/trainer/test_trainer.py | 31 ++++++++- 8 files changed, 133 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index 7c80a4a30d223..de4e0a7d2fd14 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -79,6 +79,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) + def sync_tensor(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 56f6eaa2223a3..7eccc48a5abf5 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -164,6 +164,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) + def predict(self, args): + return self._step(args) + def _step(self, args): args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) if self.trainer.amp_backend == AMPType.NATIVE: diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dd5691d6e4553..72f38fea16cb9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -14,16 +14,16 @@ """nn.Module with additional great features.""" -from abc import ABC -from argparse import Namespace import collections import copy -from functools import partial import inspect import os -from pathlib import Path import re import tempfile +from abc import ABC +from argparse import Namespace +from functools import partial +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -980,6 +980,11 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ + def predict(self, *args, **kwargs): + """ + TODO: + """ + def configure_optimizers( self, ): diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 69676cf77e079..abc41999d1277 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -28,6 +28,8 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -187,15 +189,26 @@ def __init__(self, pl_module: LightningModule): self.module = pl_module def forward(self, *inputs, **kwargs): - if self.module.training: + + if self.module.running_state == RunningStage.TRAINING: output = self.module.training_step(*inputs, **kwargs) warn_if_output_is_none(output, "training_step") - elif self.module.testing: + + elif self.module.running_state == RunningStage.TESTING: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") - else: + + elif self.module.running_state == RunningStage.EVALUATING: output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") + + elif self.module.running_state == RunningStage.PREDICTING: + output = self.module.predict(*inputs, **kwargs) + warn_if_output_is_none(output, "predict") + + else: + raise MisconfigurationException("running_stage shoud be define") + return output diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a8fa9f43684ca..c745969a143d7 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -14,6 +14,7 @@ import torch from pytorch_lightning.core.step_result import EvalResult, Result +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -169,7 +170,11 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): model_ref = self.trainer.get_model() model_ref._results = Result() # run actual test step - if self.testing: + if self.trainer.running_stage == RunningStage.PREDICTING: + model_ref._current_fx_name = "predict" + output = self.trainer.accelerator_backend.predict(args) + self.outputs.append(output) + elif self.testing: model_ref._current_fx_name = "test_step" output = self.trainer.accelerator_backend.test_step(args) else: @@ -296,6 +301,9 @@ def __auto_reduce_result_objs(self, outputs): return eval_results + def prediction_epoch_end(self): + return [dl_idx for dl_idx in range(self.num_dataloaders)], [] + def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): # set dataloader_idx to model and track batch_size self.trainer.logger_connector.on_evaluation_batch_start( diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 6909a80251850..4112da357877e 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -43,9 +43,11 @@ class RunningStage(LightningEnum): >>> RunningStage.TRAINING == 'train' True """ + UNDEFINED = None TRAINING = 'train' EVALUATING = 'eval' TESTING = 'test' + PREDICTING = 'predict' TUNING = 'tune' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 59d745a2be816..d2e67b123a050 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -53,7 +53,7 @@ from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.tuner.tuning import Tuner @@ -476,7 +476,6 @@ def fit( # ---------------------------- # hook self.call_hook('on_fit_start') - results = self.accelerator_backend.train() self.accelerator_backend.teardown() @@ -614,6 +613,8 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # lightning module methods with self.profiler.profile("evaluation_step_and_end"): output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) + if self.running_stage == RunningStage.PREDICTING: + continue output = self.evaluation_loop.evaluation_step_end(output) # hook + store predictions @@ -628,6 +629,9 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # store batch level output per dataloader self.evaluation_loop.outputs.append(dl_outputs) + if self.running_stage == RunningStage.PREDICTING: + return self.evaluation_loop.prediction_epoch_end() + # lightning module method deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() @@ -767,6 +771,61 @@ def test( return results + def predict( + self, + model: Optional[LightningModule] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + ): + r""" + + Separates from fit to make sure you never run on your test set until you want to. + + Args: + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. + If ``None``, use the weights from the last epoch to test. Default to ``best``. + + model: The model to test. + + test_dataloaders: Either a single + Pytorch Dataloader or a list of them, specifying inference samples. + + verbose: If True, prints the test results + + Returns: + The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries + """ + + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_test = verbose + + if not test_dataloaders: + raise MisconfigurationException( + 'You need to pass test_dataloaders to trainer.predict. ' + ) + + if model is None: + raise MisconfigurationException( + 'You need to pass a model to trainer.predict .' + ) + + # attach data + if test_dataloaders is not None: + self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + + self.running_stage = RunningStage.PREDICTING + self.testing = True + self.model = model + results = self.fit(model) + self.running_stage = RunningStage.UNDEFINED + self.testing = False + self.teardown('test') + + return results + def __test_using_best_weights(self, ckpt_path, test_dataloaders): model = self.get_model() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 97785d9e61a86..324a58868546f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -36,7 +36,7 @@ from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.base import BoringModel, EvalModelTemplate +from tests.base import BoringModel, EvalModelTemplate, RandomDataset @pytest.mark.parametrize("url_ckpt", [True, False]) @@ -1441,3 +1441,32 @@ def test_trainer_profiler_incorrect_arg_type(profiler): match=r"Only None, bool, str and subclasses of `BaseProfiler`" r" are valid values for `Trainer`'s `profiler` parameter. *"): Trainer(profiler=profiler) + + +def test_trainer_predict(tmpdir): + + class PredictModel(BoringModel): + + def predict(self, batch, batch_idx, dataloader_idx): + return self.layer(batch) + + def test_dataloader(self): + return [torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64))] + + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 64)), + torch.utils.data.DataLoader(RandomDataset(32, 64))] + + model = PredictModel() + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=0, + limit_val_batches=0, + limit_test_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + results = trainer.predict(model, dataloaders) + print(results) From f6261fa8ea70e437ca2a7f6c4eca7b1eac45ebdf Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 19 Jan 2021 19:24:58 +0000 Subject: [PATCH 02/43] add predict --- .../accelerators/ddp2_accelerator.py | 3 +++ .../accelerators/ddp_cpu_spawn_accelerator.py | 3 +++ .../accelerators/ddp_hpc_accelerator.py | 3 +++ .../accelerators/ddp_spawn_accelerator.py | 3 +++ .../accelerators/dp_accelerator.py | 3 +++ .../accelerators/gpu_accelerator.py | 3 +++ .../accelerators/horovod_accelerator.py | 3 +++ .../accelerators/tpu_accelerator.py | 3 +++ pytorch_lightning/overrides/data_parallel.py | 8 ++++---- .../logger_connector/logger_connector.py | 17 +++++++---------- pytorch_lightning/trainer/evaluation_loop.py | 5 ++++- pytorch_lightning/trainer/trainer.py | 17 +++++++++++++---- pytorch_lightning/trainer/training_loop.py | 4 +++- tests/trainer/test_trainer.py | 15 +++++++-------- 14 files changed, 62 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index a5e8d720ce186..215ad625eae51 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -66,6 +66,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) + def predict(self, args): + return self._step(args) + def _step(self, args): args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) if self.trainer.amp_backend == AMPType.NATIVE: diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index b15b9e8062257..3dda3ac6ef465 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -180,6 +180,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) + def predict(self, args): + return self._step(args) + def _step(self, args): args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) if self.trainer.amp_backend == AMPType.NATIVE: diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index cf6aad9999223..b576841b3a829 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -83,6 +83,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) + def predict(self, args): + return self._step(args) + def _step(self, args): args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) if self.trainer.amp_backend == AMPType.NATIVE: diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index e23943e9262f8..a26db97ce84f2 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -214,6 +214,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) + def predict(self, args): + return self._step(args) + def _step(self, args): args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args) if self.trainer.amp_backend == AMPType.NATIVE: diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 847d156d4f11d..dc5a6bacb0abf 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -134,6 +134,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) + def predict(self, args): + return self._step(args) + def training_step_end(self, output): if isinstance(output, Result): output.dp_reduce() diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index 2fe3b26679f5c..db39192dc512f 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -87,6 +87,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) + def to_device(self, batch): gpu_id = 0 if isinstance(self.trainer.data_parallel_device_ids, list): diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 150be86210866..bdb55a32e8a06 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -136,6 +136,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) + def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs) optimizer.synchronize() diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 66fc236a2a775..f1d502125aedc 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -159,6 +159,9 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) + def process_dataloader(self, dataloader): device = xm.xla_device(self.trainer.tpu_id) dataloader = xla_pl.ParallelLoader(dataloader, [device]) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index abc41999d1277..669b00b69b87b 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -190,19 +190,19 @@ def __init__(self, pl_module: LightningModule): def forward(self, *inputs, **kwargs): - if self.module.running_state == RunningStage.TRAINING: + if self.module.running_stage == RunningStage.TRAINING: output = self.module.training_step(*inputs, **kwargs) warn_if_output_is_none(output, "training_step") - elif self.module.running_state == RunningStage.TESTING: + elif self.module.running_stage == RunningStage.TESTING: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") - elif self.module.running_state == RunningStage.EVALUATING: + elif self.module.running_stage == RunningStage.EVALUATING: output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - elif self.module.running_state == RunningStage.PREDICTING: + elif self.module.running_stage == RunningStage.PREDICTING: output = self.module.predict(*inputs, **kwargs) warn_if_output_is_none(output, "predict") diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8e992f8f12034..ff9b468cba48b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -22,11 +22,12 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator -from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities import DeviceType, flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.trainer.states import RunningStage class LoggerConnector: @@ -37,9 +38,8 @@ def __init__(self, trainer): self._logged_metrics = MetricsHolder() self._progress_bar_metrics = MetricsHolder() self.eval_loop_results = [] - self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages} + self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage} self._callback_hook_validator = CallbackHookNameValidator() - self._current_stage = None @property def callback_metrics(self) -> Dict: @@ -75,7 +75,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None: @property def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results.get(self._current_stage) # type: ignore + return self._cached_results.get(self.trainer.running_stage) # type: ignore def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) @@ -90,10 +90,8 @@ def set_metrics(self, key: str, val: Any) -> None: metrics_holder = getattr(self, f"_{key}", None) metrics_holder.reset(val) - def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None: - self._current_stage = LoggerStages.determine_stage(stage_or_testing) - if reset: - self.cached_results.reset() + def reset(self) -> None: + self.cached_results.reset() def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: self._callback_hook_validator.check_logging_in_callbacks( @@ -119,8 +117,7 @@ def on_train_batch_end(self) -> None: self.cached_results._batch_size = None def cache_logged_metrics(self): - if self._current_stage is not None: - self._cached_results[self._current_stage].cache_result() + self._cached_results[self.trainer.running_stage].cache_result() def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c745969a143d7..d1eba9e20a9de 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -170,13 +170,16 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): model_ref = self.trainer.get_model() model_ref._results = Result() # run actual test step + if self.trainer.running_stage == RunningStage.PREDICTING: model_ref._current_fx_name = "predict" output = self.trainer.accelerator_backend.predict(args) - self.outputs.append(output) + return output + elif self.testing: model_ref._current_fx_name = "test_step" output = self.trainer.accelerator_backend.test_step(args) + else: model_ref._current_fx_name = "validation_step" output = self.trainer.accelerator_backend.validation_step(args) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d2e67b123a050..c7b9dcb501b3c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -497,11 +497,18 @@ def fit( self._state = TrainerState.FINISHED return results or 1 + def _set_running_stage(self, stage): + # predicting is special and should override the others + if self.running_stage == RunningStage.PREDICTING: + stage = RunningStage.PREDICTING + self.get_model().running_stage = stage + self.running_stage = stage + def train(self): self.run_sanity_check(self.get_model()) # set stage for logging - self.logger_connector.set_stage("train") + self._set_running_stage(RunningStage.TRAINING) self.checkpoint_connector.has_trained = False @@ -563,7 +570,8 @@ def train(self): def run_evaluation(self, test_mode: bool = False, max_batches=None): # used to know if we are logging for val, test + reset cached results - self.logger_connector.set_stage(test_mode, reset=True) + self._set_running_stage(RunningStage.TESTING if test_mode else RunningStage.EVALUATING) + self.logger_connector.reset() # bookkeeping self.evaluation_loop.testing = test_mode @@ -750,8 +758,8 @@ def test( # SETUP HOOK # -------------------- self.verbose_test = verbose - - self.logger_connector.set_stage("test") + self.running_stage = RunningStage.TESTING + self._set_running_stage(RunningStage.TESTING) # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: @@ -768,6 +776,7 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') + self.running_stage = RunningStage.UNDEFINED return results diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 78cb08f22161f..26f359c1c3aa9 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -31,6 +31,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.warnings import WarningCache +from pytorch_lightning.trainer.states import RunningStage class TrainLoop: @@ -598,8 +599,9 @@ def run_training_epoch(self): should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation(test_mode=False) + # reset stage to train - self.trainer.logger_connector.set_stage("train") + self.trainer._set_running_stage(RunningStage.TRAINING) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 324a58868546f..4dea0423bb741 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -14,6 +14,7 @@ import math import os import pickle +from pytorch_lightning.accelerators import accelerator import sys from argparse import Namespace from copy import deepcopy @@ -1442,17 +1443,13 @@ def test_trainer_profiler_incorrect_arg_type(profiler): r" are valid values for `Trainer`'s `profiler` parameter. *"): Trainer(profiler=profiler) +class PredictModel(BoringModel): -def test_trainer_predict(tmpdir): - - class PredictModel(BoringModel): + def predict(self, batch, batch_idx, dataloader_idx): + return self.layer(batch) - def predict(self, batch, batch_idx, dataloader_idx): - return self.layer(batch) - def test_dataloader(self): - return [torch.utils.data.DataLoader(RandomDataset(32, 64)), - torch.utils.data.DataLoader(RandomDataset(32, 64))] +def test_trainer_predict(tmpdir): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 64)), torch.utils.data.DataLoader(RandomDataset(32, 64))] @@ -1467,6 +1464,8 @@ def test_dataloader(self): max_epochs=1, log_every_n_steps=1, weights_summary=None, + gpus=2, + accelerator="ddp_spawn" ) results = trainer.predict(model, dataloaders) print(results) From 86aa7d413343d75dbec1caa0aac26fcfe15e7979 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 19 Jan 2021 19:30:12 +0000 Subject: [PATCH 03/43] resolve test --- pytorch_lightning/trainer/trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c7b9dcb501b3c..a9a072f4e1bf8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -475,8 +475,10 @@ def fit( # TRAIN # ---------------------------- # hook + self.running_stage = RunningStage.TRAINING self.call_hook('on_fit_start') results = self.accelerator_backend.train() + self.running_stage = RunningStage.UNDEFINED self.accelerator_backend.teardown() # ---------------------------- @@ -498,10 +500,13 @@ def fit( return results or 1 def _set_running_stage(self, stage): + model_ref = self.get_model() # predicting is special and should override the others if self.running_stage == RunningStage.PREDICTING: stage = RunningStage.PREDICTING - self.get_model().running_stage = stage + + if model_ref is not None: + model_ref.running_stage = stage self.running_stage = stage def train(self): From 4fb75d710996340f5dd3e8c737bbce59839fc972 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 08:45:54 +0000 Subject: [PATCH 04/43] add predict --- pytorch_lightning/core/lightning.py | 80 ++++++++++++++++++- .../logger_connector/logger_connector.py | 18 ++--- pytorch_lightning/trainer/evaluation_loop.py | 16 +++- pytorch_lightning/trainer/trainer.py | 46 +++++++---- pytorch_lightning/trainer/training_loop.py | 4 +- tests/trainer/test_trainer.py | 4 + 6 files changed, 138 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 72f38fea16cb9..2d4a9c3051ca8 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -981,8 +981,86 @@ def test_epoch_end(self, outputs): """ def predict(self, *args, **kwargs): + r""" + Operates on a single batch of data from the prediction set. + In this step you'd normally perform a forward and return the associated output. + + .. code-block:: python + + # the pseudocode for these calls + predictions = [] + for batch_idx, batch in enumerate(data): + out = predict(batch) + predictions.append(out) + predict_epoch_end(predictions) + + Args: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch. + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple test datasets used). + + Return: + + - A tensor or a list, tuple of dictionary containing tensors. + + .. code-block:: python + + # if you have one test dataloader: + def predict(self, batch, batch_idx) + + # if you have multiple test dataloaders: + def predict(self, batch, batch_idx, dataloader_idx) + + Examples: + .. code-block:: python + + def predict(self, batch, batch_idx): + x = batch + + # implement your own + out = self(x) + return out + + Note: + When the :meth:`predict` is called, the model has been put in eval mode and + PyTorch gradients have been disabled. """ - TODO: + + def predict_epoch_end( + self, outputs: List[Any] + ) -> None: + """ + Called at the end of a predict epoch with the output of all predict steps. + + .. code-block:: python + + # the pseudocode for these calls + predictions = [] + for batch_idx, batch in enumerate(data): + out = predict(batch) + predictions.append(out) + predict_epoch_end(predictions) + + Args: + outputs: List of outputs you defined in :meth:`predict`, or if there + are multiple dataloaders, a list containing a list of outputs for each dataloader + + Return: + Any + + Note: + If you didn't define a :meth:`predict`, this won't be called. + + Examples: + With a single dataloader: + + .. code-block:: python + + def predict_epoch_end(self, outputs): + assert len(outputs) == 1 + return outputs """ def configure_optimizers( diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 8e992f8f12034..669c59cb574f6 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -22,8 +22,9 @@ from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator -from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages +from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DeviceType, flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -37,9 +38,8 @@ def __init__(self, trainer): self._logged_metrics = MetricsHolder() self._progress_bar_metrics = MetricsHolder() self.eval_loop_results = [] - self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages} + self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage} self._callback_hook_validator = CallbackHookNameValidator() - self._current_stage = None @property def callback_metrics(self) -> Dict: @@ -75,7 +75,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None: @property def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results.get(self._current_stage) # type: ignore + return self._cached_results.get(self.trainer.running_stage) # type: ignore def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) @@ -90,10 +90,8 @@ def set_metrics(self, key: str, val: Any) -> None: metrics_holder = getattr(self, f"_{key}", None) metrics_holder.reset(val) - def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None: - self._current_stage = LoggerStages.determine_stage(stage_or_testing) - if reset: - self.cached_results.reset() + def reset(self) -> None: + self.cached_results.reset() def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None: self._callback_hook_validator.check_logging_in_callbacks( @@ -119,8 +117,8 @@ def on_train_batch_end(self) -> None: self.cached_results._batch_size = None def cache_logged_metrics(self): - if self._current_stage is not None: - self._cached_results[self._current_stage].cache_result() + if self.trainer.running_stage: + self._cached_results[self.trainer.running_stage].cache_result() def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c745969a143d7..b985f8a5002cc 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -136,6 +136,7 @@ def setup(self, model, max_batches, dataloaders): self.max_batches = max_batches self.num_dataloaders = self._get_num_dataloaders(dataloaders) + self._predictions = [[] for _ in range(self.num_dataloaders)] def on_evaluation_epoch_start(self, *args, **kwargs): if self.testing: @@ -173,10 +174,13 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): if self.trainer.running_stage == RunningStage.PREDICTING: model_ref._current_fx_name = "predict" output = self.trainer.accelerator_backend.predict(args) - self.outputs.append(output) + self._predictions[dataloader_idx].append(output) + return + elif self.testing: model_ref._current_fx_name = "test_step" output = self.trainer.accelerator_backend.test_step(args) + else: model_ref._current_fx_name = "validation_step" output = self.trainer.accelerator_backend.validation_step(args) @@ -301,8 +305,14 @@ def __auto_reduce_result_objs(self, outputs): return eval_results - def prediction_epoch_end(self): - return [dl_idx for dl_idx in range(self.num_dataloaders)], [] + def on_predict_epoch_end(self): + model_ref = self.trainer.get_model() + + results = self._predictions + if is_overridden('predict_epoch_end', model=model_ref): + results = model_ref.predict_epoch_end(results) + + return results, None def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): # set dataloader_idx to model and track batch_size diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d2e67b123a050..29e869d4d55e3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -415,6 +415,8 @@ def __init__( # last thing are the plugins which override whatever the trainer used by default self.plugin_connector.on_trainer_init(plugins) + self.running_stage = RunningStage.UNDEFINED + # Callback system self.on_init_end() @@ -442,6 +444,7 @@ def fit( """ # bookkeeping self._state = TrainerState.RUNNING + self._set_running_stage(RunningStage.TRAINING) # ---------------------------- # LINK DATA @@ -495,13 +498,27 @@ def fit( if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED + + self._set_running_stage(RunningStage.UNDEFINED) + return results or 1 + def _set_running_stage(self, stage): + model_ref = self.get_model() + # predicting is special and shouldn't be overriden + if self.running_stage == RunningStage.PREDICTING: + stage = RunningStage.PREDICTING + + if model_ref is not None: + model_ref.running_stage = stage + + self.running_stage = stage + def train(self): self.run_sanity_check(self.get_model()) # set stage for logging - self.logger_connector.set_stage("train") + self._set_running_stage(RunningStage.TRAINING) self.checkpoint_connector.has_trained = False @@ -563,7 +580,8 @@ def train(self): def run_evaluation(self, test_mode: bool = False, max_batches=None): # used to know if we are logging for val, test + reset cached results - self.logger_connector.set_stage(test_mode, reset=True) + self._set_running_stage(RunningStage.TESTING if test_mode else RunningStage.EVALUATING) + self.logger_connector.reset() # bookkeeping self.evaluation_loop.testing = test_mode @@ -630,7 +648,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): self.evaluation_loop.outputs.append(dl_outputs) if self.running_stage == RunningStage.PREDICTING: - return self.evaluation_loop.prediction_epoch_end() + return self.evaluation_loop.on_predict_epoch_end() # lightning module method deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end() @@ -749,10 +767,9 @@ def test( # -------------------- # SETUP HOOK # -------------------- + self.running_stage = RunningStage.TESTING self.verbose_test = verbose - self.logger_connector.set_stage("test") - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( @@ -768,13 +785,14 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') + self.running_stage = RunningStage.UNDEFINED return results def predict( self, model: Optional[LightningModule] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, ): @@ -788,7 +806,7 @@ def predict( model: The model to test. - test_dataloaders: Either a single + dataloaders: Either a single Pytorch Dataloader or a list of them, specifying inference samples. verbose: If True, prints the test results @@ -800,30 +818,30 @@ def predict( # -------------------- # SETUP HOOK # -------------------- + self.running_stage = RunningStage.PREDICTING self.verbose_test = verbose - if not test_dataloaders: + if not dataloaders: raise MisconfigurationException( - 'You need to pass test_dataloaders to trainer.predict. ' + 'You need to pass dataloaders to trainer.predict. ' ) if model is None: raise MisconfigurationException( - 'You need to pass a model to trainer.predict .' + 'You need to pass a model to trainer.predict. ' ) # attach data - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - self.running_stage = RunningStage.PREDICTING self.testing = True self.model = model results = self.fit(model) - self.running_stage = RunningStage.UNDEFINED self.testing = False self.teardown('test') + self.running_stage = RunningStage.UNDEFINED return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 78cb08f22161f..991df5c63b86a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -22,7 +22,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import EvalResult, Result -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn @@ -599,7 +599,7 @@ def run_training_epoch(self): if should_check_val: self.trainer.run_evaluation(test_mode=False) # reset stage to train - self.trainer.logger_connector.set_stage("train") + self.trainer._set_running_stage(RunningStage.TRAINING) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 324a58868546f..9572d620b2767 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1450,6 +1450,10 @@ class PredictModel(BoringModel): def predict(self, batch, batch_idx, dataloader_idx): return self.layer(batch) + def predict_epoch_end(self, predictions): + assert len(predictions) == 2 + return predictions + def test_dataloader(self): return [torch.utils.data.DataLoader(RandomDataset(32, 64)), torch.utils.data.DataLoader(RandomDataset(32, 64))] From 2d7ee29190094fa94603f0b9254de3baeb39d576 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 09:42:56 +0000 Subject: [PATCH 05/43] remove limit_predict --- .../trainer/connectors/debugging_connector.py | 1 + tests/trainer/test_trainer.py | 7 ++----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 3a5447dd945b1..c32cfce6463a2 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -81,6 +81,7 @@ def determine_data_use_amount(self, overfit_batches: float) -> None: self.trainer.limit_train_batches = overfit_batches self.trainer.limit_val_batches = overfit_batches self.trainer.limit_test_batches = overfit_batches + self.trainer.limit_predict_batches = overfit_batches def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 57b299966b0cc..8cd81556cad89 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1453,16 +1453,13 @@ def predict_epoch_end(self, predictions): assert len(predictions) == 2 return predictions - dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 64)), - torch.utils.data.DataLoader(RandomDataset(32, 64))] + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), + torch.utils.data.DataLoader(RandomDataset(32, 2))] model = PredictModel() trainer = Trainer( default_root_dir=tmpdir, - limit_train_batches=0, - limit_val_batches=0, - limit_test_batches=2, max_epochs=1, log_every_n_steps=1, weights_summary=None, From 8b8d974da2bfa9ad6431d3eff1b77df00862f8ba Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 09:59:56 +0000 Subject: [PATCH 06/43] update --- .../accelerators/ddp2_accelerator.py | 5 -- .../accelerators/dp_accelerator.py | 5 -- pytorch_lightning/overrides/data_parallel.py | 11 +++-- tests/special_tests.sh | 2 + tests/trainer/test_trainer.py | 47 +++++++++++++++---- 5 files changed, 47 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index fd42b668932d7..98f12dd807ff5 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -98,11 +98,6 @@ def test_step_end(self, output): output.dp_reduce() return output - def predict_step_end(self, output): - def _reduce(o): - return o.mean(-1) - return apply_to_collection(output, torch.Tensor, _reduce) - def set_world_ranks(self, process_idx): # Todo: required argument `process_idx` is not used self.trainer.local_rank = self.trainer.node_rank diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 806702f38e225..b42cab8009b31 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -159,11 +159,6 @@ def test_step_end(self, output): output = output.mean() return output - def predict_step_end(self, output): - def _reduce(o): - return o.mean(-1) - return apply_to_collection(output, torch.Tensor, _reduce) - def reinit_scheduler_properties(self, optimizers: list, schedulers: list): """ Reinitialize optimizer.step properties added by schedulers diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index a143a68eb9944..ea1393c87f7e4 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -289,16 +289,19 @@ def _worker(i, module, input, kwargs, device=None): # --------------- # CHANGE - if module.training: + if module.running_stage == RunningStage.TRAINING: output = module.training_step(*input, **kwargs) fx_called = 'training_step' - elif module.testing: + elif module.running_stage == RunningStage.TESTING: output = module.test_step(*input, **kwargs) fx_called = 'test_step' - else: + elif module.running_stage == RunningStage.EVALUATING: output = module.validation_step(*input, **kwargs) fx_called = 'validation_step' - + elif module.running_stage == RunningStage.PREDICTING: + output = module.predict_step(*input, **kwargs) + fx_called = 'predict_step' + if output is None: warn_missing_output(fx_called) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index ea14841c74bad..ae15bc57b1348 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -23,3 +23,5 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection # python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance python ${DEFAULTS} tests/trainer/logging_process/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp +python ${DEFAULTS} pytest tests/trainer/test_trainer.py::test_trainer_predict_ddp +python ${DEFAULTS} pytest tests/trainer/test_trainer.py::test_trainer_predict_dp diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8cd81556cad89..908c37aeefef1 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -14,6 +14,7 @@ import math import os import pickle +from pytorch_lightning.accelerators.accelerator import Accelerator import sys from argparse import Namespace from copy import deepcopy @@ -1442,17 +1443,16 @@ def test_trainer_profiler_incorrect_arg_type(profiler): r" are valid values for `Trainer`'s `profiler` parameter. *"): Trainer(profiler=profiler) +class PredictModel(BoringModel): -def test_trainer_predict(tmpdir): - class PredictModel(BoringModel): + def predict_step(self, batch, batch_idx, dataloader_idx): + return self.layer(batch) - def predict_step(self, batch, batch_idx, dataloader_idx): - return self.layer(batch) - - def predict_epoch_end(self, predictions): - assert len(predictions) == 2 - return predictions + def predict_epoch_end(self, predictions): + assert len(predictions) == 2 + return predictions +def predict(tmpdir, accelerator, gpus, num_processes): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] @@ -1463,8 +1463,37 @@ def predict_epoch_end(self, predictions): max_epochs=1, log_every_n_steps=1, weights_summary=None, + accelerator=accelerator, + gpus=gpus, + num_processes=num_processes ) results = trainer.predict(model, dataloaders) + # todo: address this in another PR + num_samples = 1 if accelerator in ["ddp", "ddp_cpu", "ddp_spawn"] else 2 assert len(results) == 2 - assert len(results[0]) == 2 + assert len(results[0]) == num_samples assert results[0][0].shape == torch.Size([1, 2]) + + +def test_trainer_predict_cpu(tmpdir): + predict(tmpdir, None, None, None) + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +def test_trainer_predict_dp(tmpdir): + predict(tmpdir, "dp", 2, None) + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") +def test_trainer_predict_ddp(tmpdir): + predict(tmpdir, "ddp", 2, None) + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_trainer_predict_ddp_spawn(tmpdir): + predict(tmpdir, "ddp_spawn", 2, None) + +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine") +def test_trainer_predict_1_gpu(tmpdir): + predict(tmpdir, None, 1, None) \ No newline at end of file From a8415c57183977c7a9e5e86275b24c03797b5f49 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 10:05:21 +0000 Subject: [PATCH 07/43] add test for predict --- pytorch_lightning/accelerators/ddp2_accelerator.py | 1 - pytorch_lightning/accelerators/dp_accelerator.py | 1 - pytorch_lightning/callbacks/progress.py | 2 +- pytorch_lightning/overrides/data_parallel.py | 5 ++++- tests/callbacks/test_progress_bar.py | 4 ++-- tests/trainer/test_trainer.py | 11 +++++++++-- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 98f12dd807ff5..9a11b7c9891c0 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -26,7 +26,6 @@ from pytorch_lightning.plugins.ddp_plugin import DDPPlugin from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index b42cab8009b31..081ac0e95a8ab 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -23,7 +23,6 @@ from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides.data_parallel import LightningDataParallel from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 32a581150f78f..acdc4439a6b6f 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -294,7 +294,7 @@ def init_validation_tqdm(self) -> tqdm: def init_test_tqdm(self, trainer=None) -> tqdm: """ Override this to customize the tqdm bar for testing. """ bar = tqdm( - desc=trainer.running_stage.name, + desc="TESTING" if trainer is None else trainer.running_stage.name, position=(2 * self.process_position), disable=self.is_disabled, leave=True, diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index ea1393c87f7e4..687964402acd2 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -292,16 +292,19 @@ def _worker(i, module, input, kwargs, device=None): if module.running_stage == RunningStage.TRAINING: output = module.training_step(*input, **kwargs) fx_called = 'training_step' + elif module.running_stage == RunningStage.TESTING: output = module.test_step(*input, **kwargs) fx_called = 'test_step' + elif module.running_stage == RunningStage.EVALUATING: output = module.validation_step(*input, **kwargs) fx_called = 'validation_step' + elif module.running_stage == RunningStage.PREDICTING: output = module.predict_step(*input, **kwargs) fx_called = 'predict_step' - + if output is None: warn_missing_output(fx_called) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8840dae54aea2..bd618e6f1033f 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -271,8 +271,8 @@ def init_validation_tqdm(self): bar = super().init_validation_tqdm() return self._mock_bar_update(bar) - def init_test_tqdm(self): - bar = super().init_test_tqdm() + def init_test_tqdm(self, trainer=None): + bar = super().init_test_tqdm(trainer=trainer) return self._mock_bar_update(bar) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 908c37aeefef1..c51c9df59822d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -14,7 +14,7 @@ import math import os import pickle -from pytorch_lightning.accelerators.accelerator import Accelerator +import platform import sys from argparse import Namespace from copy import deepcopy @@ -1443,6 +1443,7 @@ def test_trainer_profiler_incorrect_arg_type(profiler): r" are valid values for `Trainer`'s `profiler` parameter. *"): Trainer(profiler=profiler) + class PredictModel(BoringModel): def predict_step(self, batch, batch_idx, dataloader_idx): @@ -1452,6 +1453,7 @@ def predict_epoch_end(self, predictions): assert len(predictions) == 2 return predictions + def predict(tmpdir, accelerator, gpus, num_processes): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] @@ -1478,22 +1480,27 @@ def predict(tmpdir, accelerator, gpus, num_processes): def test_trainer_predict_cpu(tmpdir): predict(tmpdir, None, None, None) + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") def test_trainer_predict_dp(tmpdir): predict(tmpdir, "dp", 2, None) + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") def test_trainer_predict_ddp(tmpdir): predict(tmpdir, "ddp", 2, None) + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_trainer_predict_ddp_spawn(tmpdir): predict(tmpdir, "ddp_spawn", 2, None) + @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine") def test_trainer_predict_1_gpu(tmpdir): - predict(tmpdir, None, 1, None) \ No newline at end of file + predict(tmpdir, None, 1, None) From c59a17ba824b5a129786f14ca7c59cdf8941405b Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 10:07:38 +0000 Subject: [PATCH 08/43] typo --- pytorch_lightning/trainer/connectors/debugging_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index c32cfce6463a2..4c69b98562027 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -81,8 +81,6 @@ def determine_data_use_amount(self, overfit_batches: float) -> None: self.trainer.limit_train_batches = overfit_batches self.trainer.limit_val_batches = overfit_batches self.trainer.limit_test_batches = overfit_batches - self.trainer.limit_predict_batches = overfit_batches - def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: if 0 <= batches <= 1: From 4bfbef118153e91a99e8e9f9bee5aa6949ac9b50 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 12:00:39 +0000 Subject: [PATCH 09/43] update on comments --- pytorch_lightning/accelerators/accelerator.py | 3 - .../accelerators/cpu_accelerator.py | 4 +- .../accelerators/ddp2_accelerator.py | 2 +- .../accelerators/ddp_accelerator.py | 2 +- .../accelerators/ddp_cpu_spawn_accelerator.py | 2 +- .../accelerators/ddp_hpc_accelerator.py | 2 +- .../accelerators/ddp_spawn_accelerator.py | 2 +- .../accelerators/dp_accelerator.py | 2 +- .../accelerators/gpu_accelerator.py | 4 +- .../accelerators/horovod_accelerator.py | 4 +- .../accelerators/tpu_accelerator.py | 4 +- pytorch_lightning/core/lightning.py | 136 ------------------ pytorch_lightning/overrides/data_parallel.py | 7 +- pytorch_lightning/trainer/evaluation_loop.py | 13 +- tests/trainer/test_trainer.py | 6 + 15 files changed, 32 insertions(+), 161 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 854bf714e128b..1b3ae6f23058a 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -83,9 +83,6 @@ def test_step_end(self, output): def validation_step_end(self, output): return output - def predict_step_end(self, output): - return output - def process_dataloader(self, dataloader): return dataloader diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index 5567346504407..7033e217e34bb 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -79,8 +79,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def predict_step(self, args): - return self._step(self.trainer.model.predict_step, args) + def forward(self, args): + return self._step(self.trainer.model.forward, args) def sync_tensor(self, tensor: Union[torch.Tensor], diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 9a11b7c9891c0..9a701a4341b7b 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -66,7 +66,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def predict_step(self, args): + def forward(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 4a9d5764faa84..fa2db84865e7a 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -164,7 +164,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def predict_step(self, args): + def forward(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 73e066defc6ec..1ebb920e45952 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -180,7 +180,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def predict_step(self, args): + def forward(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index de264f732d66a..4797fd17af2be 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -83,7 +83,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def predict_step(self, args): + def forward(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 6941d72023f4c..93fcc1d1dc0f3 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -214,7 +214,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def predict_step(self, args): + def forward(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 081ac0e95a8ab..641dd9ee7da15 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -134,7 +134,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def predict_step(self, args): + def forward(self, args): return self._step(args) def training_step_end(self, output): diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index c9b76d27c0e22..8e0922d2f64fe 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -87,8 +87,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def predict_step(self, args): - return self._step(self.trainer.model.predict_step, args) + def forward(self, args): + return self._step(self.trainer.model.forward, args) def to_device(self, batch): gpu_id = 0 diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 3d6b86b53faf1..a16d882d8e64a 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -136,8 +136,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def predict_step(self, args): - return self._step(self.trainer.model.predict_step, args) + def forward(self, args): + return self._step(self.trainer.model.forward, args) def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 74a17981abdde..06befbcf0489a 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -159,8 +159,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def predict_step(self, args): - return self._step(self.trainer.model.predict_step, args) + def forward(self, args): + return self._step(self.trainer.model.forward, args) def process_dataloader(self, dataloader): device = xm.xla_device(self.trainer.tpu_id) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 8ea9eb9665736..cd486bad3180a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -927,142 +927,6 @@ def test_step_end(self, output_results): See the :ref:`multi_gpu` guide for more details. """ - def test_epoch_end( - self, outputs: List[Any] - ) -> None: - """ - Called at the end of a test epoch with the output of all test steps. - - .. code-block:: python - - # the pseudocode for these calls - test_outs = [] - for test_batch in test_data: - out = test_step(test_batch) - test_outs.append(out) - test_epoch_end(test_outs) - - Args: - outputs: List of outputs you defined in :meth:`test_step_end`, or if there - are multiple dataloaders, a list containing a list of outputs for each dataloader - - Return: - None - - Note: - If you didn't define a :meth:`test_step`, this won't be called. - - Examples: - With a single dataloader: - - .. code-block:: python - - def test_epoch_end(self, outputs): - # do something with the outputs of all test batches - all_test_preds = test_step_outputs.predictions - - some_result = calc_all_results(all_test_preds) - self.log(some_result) - - With multiple dataloaders, `outputs` will be a list of lists. The outer list contains - one entry per dataloader, while the inner list contains the individual outputs of - each test step for that dataloader. - - .. code-block:: python - - def test_epoch_end(self, outputs): - final_value = 0 - for dataloader_outputs in outputs: - for test_step_out in dataloader_outputs: - # do something - final_value += test_step_out - - self.log('final_metric', final_value) - """ - - def predict_step(self, *args, **kwargs): - r""" - Operates on a single batch of data from the prediction set. - In this step you'd normally perform a forward and return the associated output. - - .. code-block:: python - - # the pseudocode for these calls - predictions = [] - for batch_idx, batch in enumerate(data): - out = predict_step(batch) - predictions.append(out) - predict_epoch_end(predictions) - - Args: - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - batch_idx (int): The index of this batch. - dataloader_idx (int): The index of the dataloader that produced this batch - (only if multiple test datasets used). - - Return: - - - A tensor or a list, tuple of dictionary containing tensors. - - .. code-block:: python - - # if you have one test dataloader: - def predict_step(self, batch, batch_idx) - - # if you have multiple test dataloaders: - def predict_step(self, batch, batch_idx, dataloader_idx) - - Examples: - .. code-block:: python - - def predict_step(self, batch, batch_idx): - x = batch - - # implement your own - out = self(x) - return out - - Note: - When the :meth:`predict` is called, the model has been put in eval mode and - PyTorch gradients have been disabled. - """ - - def predict_epoch_end( - self, outputs: List[Any] - ) -> None: - """ - Called at the end of a predict epoch with the output of all predict steps. - - .. code-block:: python - - # the pseudocode for these calls - predictions = [] - for batch_idx, batch in enumerate(data): - out = predict_step(batch) - predictions.append(out) - predict_epoch_end(predictions) - - Args: - outputs: List of outputs you defined in :meth:`predict`, or if there - are multiple dataloaders, a list containing a list of outputs for each dataloader - - Return: - Any - - Note: - If you didn't define a :meth:`predict`, this won't be called. - - Examples: - With a single dataloader: - - .. code-block:: python - - def predict_epoch_end(self, outputs): - assert len(outputs) == 1 - return outputs - """ - def configure_optimizers( self, ): diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 687964402acd2..86e293701afa0 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -203,8 +203,7 @@ def forward(self, *inputs, **kwargs): warn_if_output_is_none(output, "validation_step") elif self.module.running_stage == RunningStage.PREDICTING: - output = self.module.predict_step(*inputs, **kwargs) - warn_if_output_is_none(output, "predict") + output = self.module(*inputs, **kwargs) else: raise MisconfigurationException("running_stage shoud be define") @@ -302,8 +301,8 @@ def _worker(i, module, input, kwargs, device=None): fx_called = 'validation_step' elif module.running_stage == RunningStage.PREDICTING: - output = module.predict_step(*input, **kwargs) - fx_called = 'predict_step' + output = module(*input, **kwargs) + fx_called = 'forward' if output is None: warn_missing_output(fx_called) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 0cfdea94f883f..3e5f722adefcf 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.utilities.apply_func import apply_to_collection import torch from pytorch_lightning.core.step_result import EvalResult, Result @@ -174,11 +175,10 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): if self.trainer.running_stage == RunningStage.PREDICTING: model_ref._current_fx_name = "predict" - predict_step_output = self.trainer.accelerator_backend.predict_step(args) - predict_step_end_output = self.trainer.call_hook("predict_step_end", predict_step_output) - self._predictions[dataloader_idx].append(predict_step_end_output) + forward_output = self.trainer.accelerator_backend.forward([args[0]]) + self._predictions[dataloader_idx].append(forward_output) self.trainer._progress_bar_callback.on_test_batch_end( - self.trainer, model_ref, predict_step_end_output, batch, batch_idx, dataloader_idx) + self.trainer, model_ref, forward_output, batch, batch_idx, dataloader_idx) return elif self.testing: @@ -316,6 +316,11 @@ def on_predict_epoch_end(self): if is_overridden('predict_epoch_end', model=model_ref): results = model_ref.predict_epoch_end(results) + def _convert_to_numpy(v): + return v.cpu().numpy() + + results = apply_to_collection(results, torch.Tensor, _convert_to_numpy) + return results, None def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c51c9df59822d..77fb7bf73ce7a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1455,6 +1455,7 @@ def predict_epoch_end(self, predictions): def predict(tmpdir, accelerator, gpus, num_processes): + dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] @@ -1504,3 +1505,8 @@ def test_trainer_predict_ddp_spawn(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine") def test_trainer_predict_1_gpu(tmpdir): predict(tmpdir, None, 1, None) + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +def test_trainer_predict_ddp_cpu(tmpdir): + predict(tmpdir, "ddp_cpu", 0, 2) \ No newline at end of file From 5184e56ed0fa08979f9d6bfa8dd98a608e6ade08 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 12:02:40 +0000 Subject: [PATCH 10/43] remove predict_step --- pytorch_lightning/trainer/connectors/debugging_connector.py | 1 + tests/trainer/test_trainer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 4c69b98562027..3a5447dd945b1 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -82,6 +82,7 @@ def determine_data_use_amount(self, overfit_batches: float) -> None: self.trainer.limit_val_batches = overfit_batches self.trainer.limit_test_batches = overfit_batches + def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: if 0 <= batches <= 1: return batches diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 77fb7bf73ce7a..caeb3e5238424 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1509,4 +1509,4 @@ def test_trainer_predict_1_gpu(tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_trainer_predict_ddp_cpu(tmpdir): - predict(tmpdir, "ddp_cpu", 0, 2) \ No newline at end of file + predict(tmpdir, "ddp_cpu", 0, 2) From 4d5f57d05719fbfb473d5957a79509f1b8caf9ee Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 12:07:10 +0000 Subject: [PATCH 11/43] update ddp_shareded --- pytorch_lightning/overrides/fairscale.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 8297d708b26c2..a6050d2f4cfe3 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE LightningShardedDataParallel = None @@ -23,10 +25,21 @@ def forward(self, *inputs, **kwargs): if self.enable_broadcast_buffers: self.sync_buffers() - if self.module.training: + if self.module.running_stage == RunningStage.TRAINING: outputs = self.module.training_step(*inputs, **kwargs) - elif self.module.testing: + + elif self.module.running_stage == RunningStage.TESTING: outputs = self.module.test_step(*inputs, **kwargs) - else: + + elif self.module.running_stage == RunningStage.EVALUATING: outputs = self.module.validation_step(*inputs, **kwargs) + + elif self.module.running_stage == RunningStage.PREDICTING: + outputs = self.module(*inputs, **kwargs) + + else: + raise MisconfigurationException( + "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") + + return outputs From 7fa90c85bd3daf65baa52641b4d01e02ff31c56d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 12:15:14 +0000 Subject: [PATCH 12/43] check ddp_sharded --- tests/trainer/test_trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index caeb3e5238424..5f0fb5ff84aa9 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1454,7 +1454,7 @@ def predict_epoch_end(self, predictions): return predictions -def predict(tmpdir, accelerator, gpus, num_processes): +def predict(tmpdir, accelerator, gpus, num_processes, plugins=None): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] @@ -1468,7 +1468,8 @@ def predict(tmpdir, accelerator, gpus, num_processes): weights_summary=None, accelerator=accelerator, gpus=gpus, - num_processes=num_processes + num_processes=num_processes, + plugins=plugins ) results = trainer.predict(model, dataloaders) # todo: address this in another PR @@ -1492,8 +1493,9 @@ def test_trainer_predict_dp(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") -def test_trainer_predict_ddp(tmpdir): - predict(tmpdir, "ddp", 2, None) +@pytest.mark.parametrize('plugins', [None, "ddp_sharded"]) +def test_trainer_predict_ddp(tmpdir, plugins): + predict(tmpdir, "ddp", 2, None, plugins=plugins) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From 392e04fe743bff15ad59c05f3c3967a4e1aff433 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 12:21:01 +0000 Subject: [PATCH 13/43] resolve on comments --- pytorch_lightning/core/lightning.py | 40 ++++++++++++++++++++++++ pytorch_lightning/overrides/fairscale.py | 7 ++--- pytorch_lightning/trainer/trainer.py | 6 +--- tests/trainer/test_trainer.py | 12 +------ 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cd486bad3180a..e2f9974352e86 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -927,6 +927,46 @@ def test_step_end(self, output_results): See the :ref:`multi_gpu` guide for more details. """ + def test_epoch_end( + self, outputs: List[Any] + ) -> None: + """ + Called at the end of a test epoch with the output of all test steps. + .. code-block:: python + # the pseudocode for these calls + test_outs = [] + for test_batch in test_data: + out = test_step(test_batch) + test_outs.append(out) + test_epoch_end(test_outs) + Args: + outputs: List of outputs you defined in :meth:`test_step_end`, or if there + are multiple dataloaders, a list containing a list of outputs for each dataloader + Return: + None + Note: + If you didn't define a :meth:`test_step`, this won't be called. + Examples: + With a single dataloader: + .. code-block:: python + def test_epoch_end(self, outputs): + # do something with the outputs of all test batches + all_test_preds = test_step_outputs.predictions + some_result = calc_all_results(all_test_preds) + self.log(some_result) + With multiple dataloaders, `outputs` will be a list of lists. The outer list contains + one entry per dataloader, while the inner list contains the individual outputs of + each test step for that dataloader. + .. code-block:: python + def test_epoch_end(self, outputs): + final_value = 0 + for dataloader_outputs in outputs: + for test_step_out in dataloader_outputs: + # do something + final_value += test_step_out + self.log('final_metric', final_value) + """ + def configure_optimizers( self, ): diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index a6050d2f4cfe3..86a3541b74e9a 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException LightningShardedDataParallel = None if _FAIRSCALE_AVAILABLE: @@ -27,7 +27,7 @@ def forward(self, *inputs, **kwargs): if self.module.running_stage == RunningStage.TRAINING: outputs = self.module.training_step(*inputs, **kwargs) - + elif self.module.running_stage == RunningStage.TESTING: outputs = self.module.test_step(*inputs, **kwargs) @@ -40,6 +40,5 @@ def forward(self, *inputs, **kwargs): else: raise MisconfigurationException( "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") - - + return outputs diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1c30215456f30..4b7f241fd4351 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -794,17 +794,13 @@ def predict( self, model: Optional[LightningModule] = None, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ckpt_path: Optional[str] = 'best', verbose: bool = True, ): r""" - Separates from fit to make sure you never run on your test set until you want to. + Separates from fit to make sure you never run on your predictions set until you want to. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the weights from the last epoch to test. Default to ``best``. - model: The model to test. dataloaders: Either a single diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 5f0fb5ff84aa9..06fce3e6cbfcc 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1444,22 +1444,12 @@ def test_trainer_profiler_incorrect_arg_type(profiler): Trainer(profiler=profiler) -class PredictModel(BoringModel): - - def predict_step(self, batch, batch_idx, dataloader_idx): - return self.layer(batch) - - def predict_epoch_end(self, predictions): - assert len(predictions) == 2 - return predictions - - def predict(tmpdir, accelerator, gpus, num_processes, plugins=None): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] - model = PredictModel() + model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, From 036e24d523994e55c7e62cb25b2de06bf5476d07 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 12:23:57 +0000 Subject: [PATCH 14/43] resolve isort --- pytorch_lightning/loggers/wandb.py | 3 ++- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 15b46a4d9bbd1..443c50ac4835f 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -29,8 +29,9 @@ _WANDB_AVAILABLE = _module_available("wandb") try: - import wandb from wandb.wandb_run import Run + + import wandb except ImportError: # needed for test mocks, these tests shall be updated wandb, Run = None, None diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 3e5f722adefcf..e4b3f7acfb9d1 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.apply_func import apply_to_collection import torch from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import PredictionCollection +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache From 4b5525ba8a2b03a5cffda77bb9bd2d5614acb815 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 12:30:01 +0000 Subject: [PATCH 15/43] update dp --- pytorch_lightning/overrides/data_parallel.py | 25 ++++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 86e293701afa0..88f7ed3dcba8b 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -80,14 +80,24 @@ def forward(self, *inputs, **kwargs): "them on device: {}".format(self.src_device_obj, t.device)) inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: - # lightning - if self.module.training: + + if self.module.running_stage == RunningStage.TRAINING: return self.module.training_step(*inputs[0], **kwargs[0]) - if self.module.testing: + + elif self.module.running_stage == RunningStage.TESTING: return self.module.test_step(*inputs[0], **kwargs[0]) - return self.module.validation_step(*inputs[0], **kwargs[0]) + elif self.module.running_stage == RunningStage.EVALUATING: + return self.module.validation_step(*inputs[0], **kwargs[0]) + + elif self.module.running_stage == RunningStage.PREDICTING: + return self.module(*inputs[0], **kwargs[0]) + + else: + raise MisconfigurationException( + "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = self.parallel_apply(replicas, inputs, kwargs) @@ -206,7 +216,8 @@ def forward(self, *inputs, **kwargs): output = self.module(*inputs, **kwargs) else: - raise MisconfigurationException("running_stage shoud be define") + raise MisconfigurationException( + "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") return output @@ -304,6 +315,10 @@ def _worker(i, module, input, kwargs, device=None): output = module(*input, **kwargs) fx_called = 'forward' + else: + raise MisconfigurationException( + "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") + if output is None: warn_missing_output(fx_called) From 97fa5b389a5359fa25784bb7825053f47d080d1e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 12:32:29 +0000 Subject: [PATCH 16/43] add test dp 1 gpu --- tests/trainer/test_trainer.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 06fce3e6cbfcc..08bf4df42cea5 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1444,12 +1444,22 @@ def test_trainer_profiler_incorrect_arg_type(profiler): Trainer(profiler=profiler) +class PredictModel(BoringModel): + + def predict_step(self, batch, batch_idx, dataloader_idx): + return self.layer(batch) + + def predict_epoch_end(self, predictions): + assert len(predictions) == 2 + return predictions + + def predict(tmpdir, accelerator, gpus, num_processes, plugins=None): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] - model = BoringModel() + model = PredictModel() trainer = Trainer( default_root_dir=tmpdir, @@ -1476,8 +1486,9 @@ def test_trainer_predict_cpu(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest") -def test_trainer_predict_dp(tmpdir): - predict(tmpdir, "dp", 2, None) +@pytest.mark.parametrize('num_gpus', [1, 2]) +def test_trainer_predict_dp(tmpdir, num_gpus): + predict(tmpdir, "dp", num_gpus, None) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From b6ed16383cae30d5e7a13ed990b50e50fb7f8ba7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 14:28:03 +0000 Subject: [PATCH 17/43] made default forward --- pytorch_lightning/callbacks/progress.py | 2 +- pytorch_lightning/overrides/data_parallel.py | 35 +++++++------------ pytorch_lightning/overrides/fairscale.py | 15 ++++---- .../logger_connector/epoch_result_store.py | 26 ++------------ .../logger_connector/logger_connector.py | 4 +-- pytorch_lightning/trainer/evaluation_loop.py | 7 +--- pytorch_lightning/trainer/trainer.py | 26 +++++++------- tests/overrides/test_data_parallel.py | 10 +++--- tests/trainer/test_trainer.py | 5 +-- 9 files changed, 46 insertions(+), 84 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index acdc4439a6b6f..2a0ea85c4f959 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -294,7 +294,7 @@ def init_validation_tqdm(self) -> tqdm: def init_test_tqdm(self, trainer=None) -> tqdm: """ Override this to customize the tqdm bar for testing. """ bar = tqdm( - desc="TESTING" if trainer is None else trainer.running_stage.name, + desc="Predicting" if trainer.is_predicting else "Testing", position=(2 * self.process_position), disable=self.is_disabled, leave=True, diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 88f7ed3dcba8b..dffc6eae01328 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -29,7 +29,6 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import WarningCache @@ -83,21 +82,19 @@ def forward(self, *inputs, **kwargs): if len(self.device_ids) == 1: - if self.module.running_stage == RunningStage.TRAINING: + running_stage = getattr(self.module, "running_stage") + + if running_stage == RunningStage.TRAINING: return self.module.training_step(*inputs[0], **kwargs[0]) - elif self.module.running_stage == RunningStage.TESTING: + elif running_stage == RunningStage.TESTING: return self.module.test_step(*inputs[0], **kwargs[0]) - elif self.module.running_stage == RunningStage.EVALUATING: + elif running_stage == RunningStage.EVALUATING: return self.module.validation_step(*inputs[0], **kwargs[0]) - elif self.module.running_stage == RunningStage.PREDICTING: - return self.module(*inputs[0], **kwargs[0]) - else: - raise MisconfigurationException( - "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") + return self.module(*inputs[0], **kwargs[0]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = self.parallel_apply(replicas, inputs, kwargs) @@ -200,24 +197,22 @@ def __init__(self, pl_module: LightningModule): def forward(self, *inputs, **kwargs): - if self.module.running_stage == RunningStage.TRAINING: + running_stage = getattr(self.module, "running_stage") + + if running_stage == RunningStage.TRAINING: output = self.module.training_step(*inputs, **kwargs) warn_if_output_is_none(output, "training_step") - elif self.module.running_stage == RunningStage.TESTING: + elif running_stage == RunningStage.TESTING: output = self.module.test_step(*inputs, **kwargs) warn_if_output_is_none(output, "test_step") - elif self.module.running_stage == RunningStage.EVALUATING: + elif running_stage == RunningStage.EVALUATING: output = self.module.validation_step(*inputs, **kwargs) warn_if_output_is_none(output, "validation_step") - elif self.module.running_stage == RunningStage.PREDICTING: - output = self.module(*inputs, **kwargs) - else: - raise MisconfigurationException( - "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") + output = self.module(*inputs, **kwargs) return output @@ -311,14 +306,10 @@ def _worker(i, module, input, kwargs, device=None): output = module.validation_step(*input, **kwargs) fx_called = 'validation_step' - elif module.running_stage == RunningStage.PREDICTING: + else: output = module(*input, **kwargs) fx_called = 'forward' - else: - raise MisconfigurationException( - "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") - if output is None: warn_missing_output(fx_called) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 86a3541b74e9a..65705a62e2310 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -13,7 +13,6 @@ # limitations under the License. from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE -from pytorch_lightning.utilities.exceptions import MisconfigurationException LightningShardedDataParallel = None if _FAIRSCALE_AVAILABLE: @@ -25,20 +24,18 @@ def forward(self, *inputs, **kwargs): if self.enable_broadcast_buffers: self.sync_buffers() - if self.module.running_stage == RunningStage.TRAINING: + running_stage = getattr(self.module, "running_stage") + + if running_stage == RunningStage.TRAINING: outputs = self.module.training_step(*inputs, **kwargs) - elif self.module.running_stage == RunningStage.TESTING: + elif running_stage == RunningStage.TESTING: outputs = self.module.test_step(*inputs, **kwargs) - elif self.module.running_stage == RunningStage.EVALUATING: + elif running_stage == RunningStage.EVALUATING: outputs = self.module.validation_step(*inputs, **kwargs) - elif self.module.running_stage == RunningStage.PREDICTING: - outputs = self.module(*inputs, **kwargs) - else: - raise MisconfigurationException( - "running_stage should either be [TRAINING, TESTING, EVALUATING, PREDICTING]") + outputs = self.module(*inputs, **kwargs) return outputs diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 2e6053d74b55f..0711a3fb3a25e 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -12,35 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import torch from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import DistributedType, LightningEnum -class LoggerStages(LightningEnum): - """ Train/validation/test phase in each training step. - - >>> # you can math the type with string - >>> LoggerStages.TRAIN == 'train' - True - """ - TRAIN = "train" - VAL = "validation" - TEST = "test" - - @staticmethod - def determine_stage(stage_or_testing: Union[str, bool]) -> 'LoggerStages': - if isinstance(stage_or_testing, str) and stage_or_testing in list(LoggerStages): - return LoggerStages(stage_or_testing) - if isinstance(stage_or_testing, (bool, int)): - # stage_or_testing is trainer.testing - return LoggerStages.TEST if bool(stage_or_testing) else LoggerStages.VAL - raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given") - - class ResultStoreType(LightningEnum): INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop" @@ -371,7 +351,7 @@ def update_logger_connector(self) -> None: callback_metrics = {} batch_pbar_metrics = {} batch_log_metrics = {} - is_train = self._stage in LoggerStages.TRAIN.value + is_train = self._stage in RunningStage.TRAINING if not self._has_batch_loop_finished: # get pbar diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d4bcc8f97e302..04de8c7d131a3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -75,7 +75,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None: @property def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results.get(self.trainer.running_stage) # type: ignore + return self._cached_results.get(self.trainer._running_stage) # type: ignore def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) @@ -117,7 +117,7 @@ def on_train_batch_end(self) -> None: self.cached_results._batch_size = None def cache_logged_metrics(self): - self._cached_results[self.trainer.running_stage].cache_result() + self._cached_results[self.trainer._running_stage].cache_result() def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool): # logging diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index e4b3f7acfb9d1..68c1052a58403 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -14,7 +14,6 @@ import torch from pytorch_lightning.core.step_result import EvalResult, Result -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -173,7 +172,7 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): model_ref._results = Result() # run actual test step - if self.trainer.running_stage == RunningStage.PREDICTING: + if self.trainer.is_predicting: model_ref._current_fx_name = "predict" forward_output = self.trainer.accelerator_backend.forward([args[0]]) self._predictions[dataloader_idx].append(forward_output) @@ -310,11 +309,7 @@ def __auto_reduce_result_objs(self, outputs): return eval_results def on_predict_epoch_end(self): - model_ref = self.trainer.get_model() - results = self._predictions - if is_overridden('predict_epoch_end', model=model_ref): - results = model_ref.predict_epoch_end(results) def _convert_to_numpy(v): return v.cpu().numpy() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4b7f241fd4351..1c4f865e21829 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -295,7 +295,8 @@ def __init__( super().__init__() self._device_type = DeviceType.CPU self._distrib_type = None - self._running_stage = None + self._running_stage = RunningStage.UNDEFINED + self.is_predicting = False # init connectors self.dev_debugger = InternalDebugger(self) @@ -416,8 +417,6 @@ def __init__( # last thing are the plugins which override whatever the trainer used by default self.plugin_connector.on_trainer_init(plugins) - self.running_stage = RunningStage.UNDEFINED - # Callback system self.on_init_end() @@ -500,20 +499,20 @@ def fit( if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED - self.running_stage = RunningStage.UNDEFINED + self._set_running_stage(RunningStage.UNDEFINED) return results or 1 def _set_running_stage(self, stage): model_ref = self.get_model() # predicting is special and shouldn't be overriden - if self.running_stage == RunningStage.PREDICTING: + if self._running_stage == RunningStage.PREDICTING: stage = RunningStage.PREDICTING if model_ref is not None: model_ref.running_stage = stage - self.running_stage = stage + self._running_stage = stage def train(self): self.run_sanity_check(self.get_model()) @@ -632,7 +631,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # lightning module methods with self.profiler.profile("evaluation_step_and_end"): output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) - if self.running_stage == RunningStage.PREDICTING: + if self.is_predicting: continue output = self.evaluation_loop.evaluation_step_end(output) @@ -648,7 +647,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # store batch level output per dataloader self.evaluation_loop.outputs.append(dl_outputs) - if self.running_stage == RunningStage.PREDICTING: + if self.is_predicting: return self.evaluation_loop.on_predict_epoch_end() # lightning module method @@ -786,7 +785,7 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') - self.running_stage = RunningStage.UNDEFINED + self._running_stage = RunningStage.UNDEFINED return results @@ -815,7 +814,7 @@ def predict( # -------------------- # SETUP HOOK # -------------------- - self.running_stage = RunningStage.PREDICTING + self.is_predicting = True self.verbose_test = verbose if not dataloaders: @@ -832,13 +831,14 @@ def predict( if dataloaders is not None: self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - self.testing = True + os.environ['PL_TESTING_MODE'] = '1' self.model = model results = self.fit(model) - self.testing = False self.teardown('test') + self.testing = False + del os.environ['PL_TESTING_MODE'] - self.running_stage = RunningStage.UNDEFINED + self.is_predicting = False return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 8c8f1649e73c7..6ae0278d3dfd6 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -4,6 +4,7 @@ import torch from pytorch_lightning.overrides.data_parallel import LightningDistributedModule +from pytorch_lightning.trainer.states import RunningStage def test_lightning_distributed_module_methods(): @@ -14,18 +15,15 @@ def test_lightning_distributed_module_methods(): batch = torch.rand(5) batch_idx = 3 - pl_module.training = True - pl_module.testing = False + pl_module.running_stage = RunningStage.TRAINING dist_module(batch, batch_idx) pl_module.training_step.assert_called_with(batch, batch_idx) - pl_module.training = False - pl_module.testing = True + pl_module.running_stage = RunningStage.TESTING dist_module(batch, batch_idx) pl_module.test_step.assert_called_with(batch, batch_idx) - pl_module.training = False - pl_module.testing = False + pl_module.running_stage = RunningStage.EVALUATING dist_module(batch, batch_idx) pl_module.validation_step.assert_called_with(batch, batch_idx) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 08bf4df42cea5..556053f1969fb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1469,7 +1469,8 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None): accelerator=accelerator, gpus=gpus, num_processes=num_processes, - plugins=plugins + plugins=plugins, + num_sanity_val_steps=0 ) results = trainer.predict(model, dataloaders) # todo: address this in another PR @@ -1480,7 +1481,7 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None): def test_trainer_predict_cpu(tmpdir): - predict(tmpdir, None, None, None) + predict(tmpdir, None, None, 1) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") From bb68b9ede954e78a087b282abd61a7ba9b75e3b2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 15:31:29 +0000 Subject: [PATCH 18/43] resolve path --- pytorch_lightning/trainer/deprecated_api.py | 5 ++++- pytorch_lightning/trainer/trainer.py | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index dbfa3258b2ed1..46b2104f992bf 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -153,4 +153,7 @@ def testing(self, val: bool) -> None: if val: self._running_stage = RunningStage.TESTING else: - self._running_stage = None + if self._running_stage == RunningStage.TRAINING: + pass + else: + self._running_stage = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1c4f865e21829..6c04062c78dd8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -779,13 +779,15 @@ def test( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') + os.environ['PL_TESTING_MODE'] = '1' if model is not None: results = self.__test_given_model(model, test_dataloaders) else: results = self.__test_using_best_weights(ckpt_path, test_dataloaders) + del os.environ['PL_TESTING_MODE'] self.teardown('test') - self._running_stage = RunningStage.UNDEFINED + self._set_running_stage(RunningStage.UNDEFINED) return results @@ -814,7 +816,6 @@ def predict( # -------------------- # SETUP HOOK # -------------------- - self.is_predicting = True self.verbose_test = verbose if not dataloaders: @@ -831,14 +832,19 @@ def predict( if dataloaders is not None: self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) + # set path variable + self.is_predicting = True os.environ['PL_TESTING_MODE'] = '1' self.model = model + results = self.fit(model) + + # unset path variable self.teardown('test') - self.testing = False del os.environ['PL_TESTING_MODE'] - self.is_predicting = False + self._set_running_stage(RunningStage.UNDEFINED) + return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): @@ -875,11 +881,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # run tests self.tested_ckpt_path = ckpt_path self.testing = True - os.environ['PL_TESTING_MODE'] = '1' self.model = model results = self.fit(model) self.testing = False - del os.environ['PL_TESTING_MODE'] # teardown if self.is_function_implemented('teardown'): @@ -896,10 +900,8 @@ def __test_given_model(self, model, test_dataloaders): # run test # sets up testing so we short circuit to eval - self.testing = True self.model = model results = self.fit(model) - self.testing = False # teardown if self.is_function_implemented('teardown'): From dc6735d3f5b701ed1b031a56bc6523264447a405 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 15:40:06 +0000 Subject: [PATCH 19/43] resolve bug --- pytorch_lightning/trainer/trainer.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6c04062c78dd8..4520f83fe19d2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -505,12 +505,15 @@ def fit( def _set_running_stage(self, stage): model_ref = self.get_model() - # predicting is special and shouldn't be overriden - if self._running_stage == RunningStage.PREDICTING: - stage = RunningStage.PREDICTING - + + # WARNING: With predicting, + # trainer _running_state should be RunningStage.TESTING + # however, the model running_stage should be RunningStage.PREDICTING or None if model_ref is not None: - model_ref.running_stage = stage + if self.is_predicting: + model_ref.running_stage = RunningStage.PREDICTING + else: + model_ref.running_stage = stage self._running_stage = stage From cf381439d9f4409281446dbb6464a530675513d8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 15:54:15 +0000 Subject: [PATCH 20/43] update on comments --- pytorch_lightning/trainer/trainer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4520f83fe19d2..1576ac76b2f57 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -505,7 +505,7 @@ def fit( def _set_running_stage(self, stage): model_ref = self.get_model() - + # WARNING: With predicting, # trainer _running_state should be RunningStage.TESTING # however, the model running_stage should be RunningStage.PREDICTING or None @@ -798,20 +798,19 @@ def predict( self, model: Optional[LightningModule] = None, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - verbose: bool = True, ): r""" Separates from fit to make sure you never run on your predictions set until you want to. + This will call the model forward function to compute predictions. + Args: - model: The model to test. + model: The model to predict on. dataloaders: Either a single Pytorch Dataloader or a list of them, specifying inference samples. - verbose: If True, prints the test results - Returns: The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries """ @@ -819,11 +818,10 @@ def predict( # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose - - if not dataloaders: + if (not isinstance(dataloaders, DataLoader) and not isinstance(dataloaders, (list, tuple))) \ + or (isinstance(dataloaders, (list, tuple)) and not all(isinstance(d, DataLoader) for d in dataloaders)): raise MisconfigurationException( - 'You need to pass dataloaders to trainer.predict. ' + 'You need to pass a dataloader or a list of dataloaders to trainer.predict. ' ) if model is None: From 6de15f32ff0cd97593398165a38e01198f1d7633 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 16:03:04 +0000 Subject: [PATCH 21/43] resolve doc --- pytorch_lightning/core/lightning.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e2f9974352e86..92d788c1d1497 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -932,38 +932,51 @@ def test_epoch_end( ) -> None: """ Called at the end of a test epoch with the output of all test steps. + .. code-block:: python + # the pseudocode for these calls test_outs = [] for test_batch in test_data: out = test_step(test_batch) test_outs.append(out) test_epoch_end(test_outs) + Args: outputs: List of outputs you defined in :meth:`test_step_end`, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader + Return: None + Note: If you didn't define a :meth:`test_step`, this won't be called. + Examples: With a single dataloader: + .. code-block:: python + def test_epoch_end(self, outputs): # do something with the outputs of all test batches all_test_preds = test_step_outputs.predictions + some_result = calc_all_results(all_test_preds) self.log(some_result) + With multiple dataloaders, `outputs` will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each test step for that dataloader. + .. code-block:: python + def test_epoch_end(self, outputs): final_value = 0 for dataloader_outputs in outputs: for test_step_out in dataloader_outputs: # do something final_value += test_step_out + self.log('final_metric', final_value) """ From 147bdd796b1ad28685aa703d9b1fd8317b0140b5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 16:45:00 +0000 Subject: [PATCH 22/43] resolve bug --- pytorch_lightning/trainer/trainer.py | 127 ++++++++++++++------------- 1 file changed, 64 insertions(+), 63 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1576ac76b2f57..e9d4f8d936518 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -506,6 +506,10 @@ def fit( def _set_running_stage(self, stage): model_ref = self.get_model() + # todo: clean up this routing mess. + if self._running_stage == RunningStage.TESTING and stage != RunningStage.TESTING: + stage = RunningStage.TESTING + # WARNING: With predicting, # trainer _running_state should be RunningStage.TESTING # however, the model running_stage should be RunningStage.PREDICTING or None @@ -748,31 +752,25 @@ def test( datamodule: Optional[LightningDataModule] = None, ): r""" - Separates from fit to make sure you never run on your test set until you want to. - Args: ckpt_path: Either ``best`` or path to the checkpoint you wish to test. If ``None``, use the weights from the last epoch to test. Default to ``best``. - datamodule: A instance of :class:`LightningDataModule`. - model: The model to test. - test_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. - verbose: If True, prints the test results - Returns: The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries """ # -------------------- # SETUP HOOK # -------------------- - self._set_running_stage(RunningStage.TESTING) self.verbose_test = verbose + self._set_running_stage(RunningStage.TESTING) + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( @@ -782,68 +780,13 @@ def test( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') - os.environ['PL_TESTING_MODE'] = '1' if model is not None: results = self.__test_given_model(model, test_dataloaders) else: results = self.__test_using_best_weights(ckpt_path, test_dataloaders) - del os.environ['PL_TESTING_MODE'] self.teardown('test') - self._set_running_stage(RunningStage.UNDEFINED) - - return results - - def predict( - self, - model: Optional[LightningModule] = None, - dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, - ): - r""" - - Separates from fit to make sure you never run on your predictions set until you want to. - - This will call the model forward function to compute predictions. - - Args: - model: The model to predict on. - - dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying inference samples. - - Returns: - The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries - """ - - # -------------------- - # SETUP HOOK - # -------------------- - if (not isinstance(dataloaders, DataLoader) and not isinstance(dataloaders, (list, tuple))) \ - or (isinstance(dataloaders, (list, tuple)) and not all(isinstance(d, DataLoader) for d in dataloaders)): - raise MisconfigurationException( - 'You need to pass a dataloader or a list of dataloaders to trainer.predict. ' - ) - - if model is None: - raise MisconfigurationException( - 'You need to pass a model to trainer.predict. ' - ) - - # attach data - if dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) - - # set path variable - self.is_predicting = True - os.environ['PL_TESTING_MODE'] = '1' - self.model = model - - results = self.fit(model) - # unset path variable - self.teardown('test') - del os.environ['PL_TESTING_MODE'] - self.is_predicting = False self._set_running_stage(RunningStage.UNDEFINED) return results @@ -882,9 +825,11 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): # run tests self.tested_ckpt_path = ckpt_path self.testing = True + os.environ['PL_TESTING_MODE'] = '1' self.model = model results = self.fit(model) self.testing = False + del os.environ['PL_TESTING_MODE'] # teardown if self.is_function_implemented('teardown'): @@ -901,8 +846,10 @@ def __test_given_model(self, model, test_dataloaders): # run test # sets up testing so we short circuit to eval + self.testing = True self.model = model results = self.fit(model) + self.testing = False # teardown if self.is_function_implemented('teardown'): @@ -910,6 +857,60 @@ def __test_given_model(self, model, test_dataloaders): return results + def predict( + self, + model: Optional[LightningModule] = None, + dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ): + r""" + + Separates from fit to make sure you never run on your predictions set until you want to. + + This will call the model forward function to compute predictions. + + Args: + model: The model to predict on. + + dataloaders: Either a single + Pytorch Dataloader or a list of them, specifying inference samples. + + Returns: + The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries + """ + + # -------------------- + # SETUP HOOK + # -------------------- + if (not isinstance(dataloaders, DataLoader) and not isinstance(dataloaders, (list, tuple))) \ + or (isinstance(dataloaders, (list, tuple)) and not all(isinstance(d, DataLoader) for d in dataloaders)): + raise MisconfigurationException( + 'You need to pass a dataloader or a list of dataloaders to trainer.predict. ' + ) + + if model is None: + raise MisconfigurationException( + 'You need to pass a model to trainer.predict. ' + ) + + # attach data + if dataloaders is not None: + self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) + + # set path variable + self.is_predicting = True + os.environ['PL_TESTING_MODE'] = '1' + self.model = model + + results = self.fit(model) + + # unset path variable + self.teardown('test') + del os.environ['PL_TESTING_MODE'] + self.is_predicting = False + self._set_running_stage(RunningStage.UNDEFINED) + + return results + def tune( self, model: LightningModule, From 895a4ba271906547b955defecd20667d5bc0c6ee Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 17:25:07 +0000 Subject: [PATCH 23/43] update --- tests/overrides/test_data_parallel.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 6ae0278d3dfd6..e61b81fd8488e 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -38,16 +38,13 @@ def test_lightning_distributed_module_warn_none_output(): pl_module.test_step.return_value = None with pytest.warns(UserWarning, match="Your training_step returned None"): - pl_module.training = True - pl_module.testing = False + pl_module.running_stage = RunningStage.TRAINING dist_module() with pytest.warns(UserWarning, match="Your test_step returned None"): - pl_module.training = False - pl_module.testing = True + pl_module.running_stage = RunningStage.TESTING dist_module() with pytest.warns(UserWarning, match="Your validation_step returned None"): - pl_module.training = False - pl_module.testing = False + pl_module.running_stage = RunningStage.EVALUATING dist_module() From 45ebba9f9babb34af7bd0304be614a1bad9a2ef5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 20 Jan 2021 19:06:10 +0000 Subject: [PATCH 24/43] resolve bug --- tests/base/develop_pipelines.py | 2 +- tests/models/test_restore.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 4949d53fc9a50..9f9e541cbffbc 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -14,7 +14,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState from pytorch_lightning.utilities import DistributedType from tests.base import BoringModel from tests.base.develop_utils import get_default_logger, load_model_from_checkpoint, reset_seed diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index e29afe8e66e55..5629e00236f3f 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -25,7 +25,7 @@ import tests.base.develop_utils as tutils from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.trainer.states import TrainerState +from pytorch_lightning.trainer.states import RunningStage, TrainerState from tests.base import BoringModel, EvalModelTemplate, GenericEvalModelTemplate @@ -398,6 +398,7 @@ def assert_good_acc(): # haven't trained with the new loaded model dp_model = new_trainer.model dp_model.eval() + dp_model.module.running_stage = RunningStage.EVALUATING dataloader = trainer.train_dataloader tpipes.run_prediction(dp_model, dataloader, dp=True) From e58c4e7875324c174c2ad58e8d4ab089fae2fdfe Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 19:44:33 +0000 Subject: [PATCH 25/43] update on comments --- pytorch_lightning/core/lightning.py | 2 ++ pytorch_lightning/overrides/data_parallel.py | 4 +-- pytorch_lightning/overrides/fairscale.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 1 - pytorch_lightning/trainer/states.py | 1 - pytorch_lightning/trainer/trainer.py | 27 +++++++++++++------- 6 files changed, 23 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 92d788c1d1497..5a3edae08c5a6 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -66,6 +66,7 @@ class LightningModule( "on_gpu", "current_epoch", "global_step", + "running_stage", ] + DeviceDtypeModuleMixin.__jit_unused_properties__ def __init__(self, *args, **kwargs): @@ -102,6 +103,7 @@ def __init__(self, *args, **kwargs): self._running_manual_backward = False self._current_hook_fx_name = None self._current_dataloader_idx = None + self.running_stage = None def optimizers(self): opts = self.trainer.optimizers diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index dffc6eae01328..a38f2863c94d2 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -82,7 +82,7 @@ def forward(self, *inputs, **kwargs): if len(self.device_ids) == 1: - running_stage = getattr(self.module, "running_stage") + running_stage = self.module.running_stage if running_stage == RunningStage.TRAINING: return self.module.training_step(*inputs[0], **kwargs[0]) @@ -197,7 +197,7 @@ def __init__(self, pl_module: LightningModule): def forward(self, *inputs, **kwargs): - running_stage = getattr(self.module, "running_stage") + running_stage = self.module.running_stage if running_stage == RunningStage.TRAINING: output = self.module.training_step(*inputs, **kwargs) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 65705a62e2310..724054751a60b 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -24,7 +24,7 @@ def forward(self, *inputs, **kwargs): if self.enable_broadcast_buffers: self.sync_buffers() - running_stage = getattr(self.module, "running_stage") + running_stage = self.module.running_stage if running_stage == RunningStage.TRAINING: outputs = self.module.training_step(*inputs, **kwargs) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 68c1052a58403..30329a87a7c44 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -170,7 +170,6 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): model_ref = self.trainer.get_model() model_ref._results = Result() - # run actual test step if self.trainer.is_predicting: model_ref._current_fx_name = "predict" diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 4112da357877e..a3ef08df1e49e 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -43,7 +43,6 @@ class RunningStage(LightningEnum): >>> RunningStage.TRAINING == 'train' True """ - UNDEFINED = None TRAINING = 'train' EVALUATING = 'eval' TESTING = 'test' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e9d4f8d936518..133c7ea5bba62 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -295,7 +295,7 @@ def __init__( super().__init__() self._device_type = DeviceType.CPU self._distrib_type = None - self._running_stage = RunningStage.UNDEFINED + self._running_stage = None self.is_predicting = False # init connectors @@ -499,15 +499,20 @@ def fit( if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED - self._set_running_stage(RunningStage.UNDEFINED) + self._set_running_stage(None) return results or 1 def _set_running_stage(self, stage): model_ref = self.get_model() + if stage is None: + self._running_stage = stage + model_ref.running_stage = stage + return + # todo: clean up this routing mess. - if self._running_stage == RunningStage.TESTING and stage != RunningStage.TESTING: + if self._running_stage == RunningStage.TESTING: stage = RunningStage.TESTING # WARNING: With predicting, @@ -761,6 +766,7 @@ def test( test_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. verbose: If True, prints the test results + Returns: The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries """ @@ -787,7 +793,7 @@ def test( self.teardown('test') - self._set_running_stage(RunningStage.UNDEFINED) + self._set_running_stage(None) return results @@ -881,15 +887,18 @@ def predict( # -------------------- # SETUP HOOK # -------------------- - if (not isinstance(dataloaders, DataLoader) and not isinstance(dataloaders, (list, tuple))) \ - or (isinstance(dataloaders, (list, tuple)) and not all(isinstance(d, DataLoader) for d in dataloaders)): + if not ( + isinstance(dataloaders, DataLoader) + or isinstance(dataloaders, (list, tuple)) + and all(isinstance(d, DataLoader) for d in dataloaders) + ): raise MisconfigurationException( - 'You need to pass a dataloader or a list of dataloaders to trainer.predict. ' + 'You need to pass a dataloader or a list of dataloaders to `trainer.predict`. ' ) if model is None: raise MisconfigurationException( - 'You need to pass a model to trainer.predict. ' + 'You need to pass a model to `trainer.predict`. ' ) # attach data @@ -907,7 +916,7 @@ def predict( self.teardown('test') del os.environ['PL_TESTING_MODE'] self.is_predicting = False - self._set_running_stage(RunningStage.UNDEFINED) + self._set_running_stage(None) return results From 6b4f76fae71d83565ebbb9d6ea9d3ba687fc1e86 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 19:46:59 +0000 Subject: [PATCH 26/43] resolve pep8 --- tests/base/develop_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py index 9f9e541cbffbc..4949d53fc9a50 100644 --- a/tests/base/develop_pipelines.py +++ b/tests/base/develop_pipelines.py @@ -14,7 +14,7 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.trainer.states import RunningStage, TrainerState +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import DistributedType from tests.base import BoringModel from tests.base.develop_utils import get_default_logger, load_model_from_checkpoint, reset_seed From f0bdbd372ab421ba0ae87bd62a8011d59714a08c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 19:59:33 +0000 Subject: [PATCH 27/43] update test doc --- pytorch_lightning/trainer/trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 133c7ea5bba62..bbcc8894bca75 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -757,14 +757,20 @@ def test( datamodule: Optional[LightningDataModule] = None, ): r""" + Separates from fit to make sure you never run on your test set until you want to. + Args: ckpt_path: Either ``best`` or path to the checkpoint you wish to test. If ``None``, use the weights from the last epoch to test. Default to ``best``. + datamodule: A instance of :class:`LightningDataModule`. + model: The model to test. + test_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. + verbose: If True, prints the test results Returns: From 0a2efb22c90bbff4c152db35ba8122b531021bcd Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 20 Jan 2021 20:24:12 +0000 Subject: [PATCH 28/43] update on comments --- .../logger_connector/epoch_result_store.py | 23 ++++++++++++++++++- .../logger_connector/logger_connector.py | 1 + pytorch_lightning/trainer/deprecated_api.py | 8 +++---- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 0711a3fb3a25e..46263f3a7e74c 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict -from typing import Any, Dict, List, Optional +from enum import Enum +from typing import Any, Dict, List, Optional, Union import torch @@ -21,6 +22,26 @@ from pytorch_lightning.utilities import DistributedType, LightningEnum +class LoggerStages(str, Enum): + """ Train/validation/test phase in each training step. + >>> # you can math the type with string + >>> LoggerStages.TRAIN == 'train' + True + """ + TRAIN = "train" + VAL = "validation" + TEST = "test" + + @staticmethod + def determine_stage(stage_or_testing: Union[str, bool]) -> 'LoggerStages': + if isinstance(stage_or_testing, str) and stage_or_testing in list(LoggerStages): + return LoggerStages(stage_or_testing) + if isinstance(stage_or_testing, (bool, int)): + # stage_or_testing is trainer.testing + return LoggerStages.TEST if bool(stage_or_testing) else LoggerStages.VAL + raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given") + + class ResultStoreType(LightningEnum): INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop" OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop" diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 04de8c7d131a3..954c053db6201 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -39,6 +39,7 @@ def __init__(self, trainer): self._progress_bar_metrics = MetricsHolder() self.eval_loop_results = [] self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage} + self._cached_results[None] = EpochResultStore(trainer, None) self._callback_hook_validator = CallbackHookNameValidator() @property diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index 46b2104f992bf..05c419957e95b 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -152,8 +152,6 @@ def testing(self) -> bool: def testing(self, val: bool) -> None: if val: self._running_stage = RunningStage.TESTING - else: - if self._running_stage == RunningStage.TRAINING: - pass - else: - self._running_stage = None + + elif self._running_stage != RunningStage.TRAINING: + self._running_stage = None From 992f360d182e940f6618e2d89b2996a30fcadeb9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 21 Jan 2021 08:19:35 +0000 Subject: [PATCH 29/43] solve special tests --- tests/special_tests.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index ae15bc57b1348..d4033f8da88b9 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -23,5 +23,5 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection # python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance python ${DEFAULTS} tests/trainer/logging_process/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp -python ${DEFAULTS} pytest tests/trainer/test_trainer.py::test_trainer_predict_ddp -python ${DEFAULTS} pytest tests/trainer/test_trainer.py::test_trainer_predict_dp +python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp +python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp From a366c2c574706d5310fbe4eb709efeef5e986c7d Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 21 Jan 2021 09:40:52 +0000 Subject: [PATCH 30/43] resolve bug --- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- pytorch_lightning/trainer/evaluation_loop.py | 3 +++ pytorch_lightning/trainer/trainer.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 80d4c8952a1f3..f249381c8d6c9 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -52,7 +52,7 @@ def __verify_train_loop_configuration(self, model): # verify model has a train dataloader # ----------------------------------- has_train_dataloader = is_overridden('train_dataloader', model) - if not has_train_dataloader: + if not has_train_dataloader and not self.trainer.is_predicting: raise MisconfigurationException( 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' @@ -62,7 +62,7 @@ def __verify_train_loop_configuration(self, model): # verify model has optimizer # ----------------------------------- has_optimizers = is_overridden('configure_optimizers', model) - if not has_optimizers: + if not has_optimizers and not self.trainer.is_predicting: raise MisconfigurationException( 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 30329a87a7c44..69ca810b420e0 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -308,6 +308,9 @@ def __auto_reduce_result_objs(self, outputs): return eval_results def on_predict_epoch_end(self): + self.trainer._progress_bar_callback.on_test_end( + self.trainer, self.trainer.get_model()) + results = self._predictions def _convert_to_numpy(v): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bbcc8894bca75..0cd233e34761f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -774,7 +774,7 @@ def test( verbose: If True, prints the test results Returns: - The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries + Returns a list of dictionaries, one for each test dataloader containing their respective metrics. """ # -------------------- # SETUP HOOK From 9137b16068fe03e6db8df548235363e5f5476aac Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 11:13:45 +0000 Subject: [PATCH 31/43] resolve flake8 --- pytorch_lightning/trainer/evaluation_loop.py | 1 - tests/base/model_test_steps.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 7bf56791eacd1..c68c17a361201 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -16,7 +16,6 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache diff --git a/tests/base/model_test_steps.py b/tests/base/model_test_steps.py index dfbbd7d2d31e6..db70959bfddef 100644 --- a/tests/base/model_test_steps.py +++ b/tests/base/model_test_steps.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import random from abc import ABC from collections import OrderedDict From b4b860fe674b15fb531d3e2a7d95f73965757af6 Mon Sep 17 00:00:00 2001 From: chaton Date: Tue, 26 Jan 2021 12:12:57 +0000 Subject: [PATCH 32/43] Update pytorch_lightning/callbacks/progress.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/callbacks/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 4dfbbdd2a40ad..dc25357fc198e 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -294,8 +294,8 @@ def init_validation_tqdm(self) -> tqdm: def init_test_tqdm(self, trainer=None) -> tqdm: """ Override this to customize the tqdm bar for testing. """ desc = "Testing" - if trainer is not None: - desc = "Predicting" if getattr(trainer, "is_predicting", False) else desc + if trainer is not None and getattr(trainer, "is_predicting", False): + desc = "Predicting" bar = tqdm( desc=desc, position=(2 * self.process_position), From 7526efe577fe11efff5595ebfc5f4712ac0b4f64 Mon Sep 17 00:00:00 2001 From: chaton Date: Tue, 26 Jan 2021 12:13:10 +0000 Subject: [PATCH 33/43] Update pytorch_lightning/trainer/trainer.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 74af3f5e2d5e0..928800590f0e0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -906,7 +906,7 @@ def predict( if datamodule is not None: # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') + self.data_connector.attach_datamodule(model, datamodule, 'test') # attach data if dataloaders is not None: From dca009d8d71b78b9e632bc6d7f2d00ac34610b36 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 13:22:34 +0000 Subject: [PATCH 34/43] add predict to LightningModule --- pytorch_lightning/accelerators/cpu_accelerator.py | 4 ++-- pytorch_lightning/accelerators/ddp2_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_hpc_accelerator.py | 2 +- pytorch_lightning/accelerators/ddp_spawn_accelerator.py | 2 +- pytorch_lightning/accelerators/dp_accelerator.py | 2 +- pytorch_lightning/accelerators/gpu_accelerator.py | 4 ++-- pytorch_lightning/accelerators/horovod_accelerator.py | 4 ++-- pytorch_lightning/accelerators/tpu_accelerator.py | 4 ++-- pytorch_lightning/core/lightning.py | 6 ++++++ pytorch_lightning/overrides/data_parallel.py | 4 ++-- pytorch_lightning/trainer/evaluation_loop.py | 4 ++-- tests/trainer/test_trainer.py | 2 +- 14 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index 7033e217e34bb..de4e0a7d2fd14 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -79,8 +79,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def forward(self, args): - return self._step(self.trainer.model.forward, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) def sync_tensor(self, tensor: Union[torch.Tensor], diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index 9a701a4341b7b..29b5ac0b88d32 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -66,7 +66,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def forward(self, args): + def predict(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index fa2db84865e7a..7eccc48a5abf5 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -164,7 +164,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def forward(self, args): + def predict(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 1ebb920e45952..3dda3ac6ef465 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -180,7 +180,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def forward(self, args): + def predict(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 4797fd17af2be..b576841b3a829 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -83,7 +83,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def forward(self, args): + def predict(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index 93fcc1d1dc0f3..a26db97ce84f2 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -214,7 +214,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def forward(self, args): + def predict(self, args): return self._step(args) def _step(self, args): diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 641dd9ee7da15..dc5a6bacb0abf 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -134,7 +134,7 @@ def validation_step(self, args): def test_step(self, args): return self._step(args) - def forward(self, args): + def predict(self, args): return self._step(args) def training_step_end(self, output): diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index 8e0922d2f64fe..db39192dc512f 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -87,8 +87,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def forward(self, args): - return self._step(self.trainer.model.forward, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) def to_device(self, batch): gpu_id = 0 diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index a16d882d8e64a..bdb55a32e8a06 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -136,8 +136,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def forward(self, args): - return self._step(self.trainer.model.forward, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs) diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 06befbcf0489a..f1d502125aedc 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -159,8 +159,8 @@ def validation_step(self, args): def test_step(self, args): return self._step(self.trainer.model.test_step, args) - def forward(self, args): - return self._step(self.trainer.model.forward, args) + def predict(self, args): + return self._step(self.trainer.model.predict, args) def process_dataloader(self, dataloader): device = xm.xla_device(self.trainer.tpu_id) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e47bbdba89111..613a50cbd83d9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -977,6 +977,12 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ + def predict(self, x: Any): + """ + Use this function with trainer.predict(...). Override if you need to add any processing logic. + """ + return self(x) + def configure_optimizers( self, ): diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index a38f2863c94d2..d94252c80a565 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -212,7 +212,7 @@ def forward(self, *inputs, **kwargs): warn_if_output_is_none(output, "validation_step") else: - output = self.module(*inputs, **kwargs) + output = self.module.predict(*inputs, **kwargs) return output @@ -307,7 +307,7 @@ def _worker(i, module, input, kwargs, device=None): fx_called = 'validation_step' else: - output = module(*input, **kwargs) + output = module.predict(*input, **kwargs) fx_called = 'forward' if output is None: diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c1588842ed30d..52a5dd30ddd8a 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -165,8 +165,8 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): model_ref._results = Result() if self.trainer.is_predicting: - model_ref._current_fx_name = "forward" - forward_output = self.trainer.accelerator_backend.forward([args[0]]) + model_ref._current_fx_name = "predict" + forward_output = self.trainer.accelerator_backend.predict([args[0]]) self._predictions[dataloader_idx].append(forward_output) self.trainer._progress_bar_callback.on_test_batch_end( self.trainer, model_ref, forward_output, batch, batch_idx, dataloader_idx) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 8d853a155ffae..c9440bb2d4e02 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1524,7 +1524,7 @@ def test_trainer_predict_1_gpu(tmpdir): def test_trainer_predict_ddp_cpu(tmpdir): predict(tmpdir, "ddp_cpu", 0, 2) - + def test_pytorch_profiler_describe(pytorch_profiler): """Ensure the profiler won't fail when reporting the summary.""" with pytorch_profiler.profile("test_step"): From 1b1a8aa6593a428c8c3df97aa3c2102910fe3d75 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 13:25:32 +0000 Subject: [PATCH 35/43] missing predict --- pytorch_lightning/overrides/data_parallel.py | 4 ++-- pytorch_lightning/overrides/fairscale.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index d94252c80a565..8ba67aba8b3e4 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -94,7 +94,7 @@ def forward(self, *inputs, **kwargs): return self.module.validation_step(*inputs[0], **kwargs[0]) else: - return self.module(*inputs[0], **kwargs[0]) + return self.module.predict(*inputs[0], **kwargs[0]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = self.parallel_apply(replicas, inputs, kwargs) @@ -308,7 +308,7 @@ def _worker(i, module, input, kwargs, device=None): else: output = module.predict(*input, **kwargs) - fx_called = 'forward' + fx_called = 'predict' if output is None: warn_missing_output(fx_called) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 724054751a60b..f413065f627ff 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -36,6 +36,6 @@ def forward(self, *inputs, **kwargs): outputs = self.module.validation_step(*inputs, **kwargs) else: - outputs = self.module(*inputs, **kwargs) + outputs = self.module.predict(*inputs, **kwargs) return outputs From 21710f96b96da7fe875574d3c61ec81daf515726 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 13:34:42 +0000 Subject: [PATCH 36/43] typo --- pytorch_lightning/trainer/evaluation_loop.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 52a5dd30ddd8a..7391bf4619640 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -166,10 +166,10 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): if self.trainer.is_predicting: model_ref._current_fx_name = "predict" - forward_output = self.trainer.accelerator_backend.predict([args[0]]) - self._predictions[dataloader_idx].append(forward_output) + predictions = self.trainer.accelerator_backend.predict([args[0]]) + self._predictions[dataloader_idx].append(predictions) self.trainer._progress_bar_callback.on_test_batch_end( - self.trainer, model_ref, forward_output, batch, batch_idx, dataloader_idx) + self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx) return elif self.testing: From 4752bd7042cae1c2aa379df676fb2047b7d7dbcf Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 13:38:35 +0000 Subject: [PATCH 37/43] rename is_prediction to _predicting --- pytorch_lightning/trainer/configuration_validator.py | 4 ++-- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index f249381c8d6c9..e12eacb9cfce4 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -52,7 +52,7 @@ def __verify_train_loop_configuration(self, model): # verify model has a train dataloader # ----------------------------------- has_train_dataloader = is_overridden('train_dataloader', model) - if not has_train_dataloader and not self.trainer.is_predicting: + if not has_train_dataloader and not self.trainer._predicting: raise MisconfigurationException( 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' @@ -62,7 +62,7 @@ def __verify_train_loop_configuration(self, model): # verify model has optimizer # ----------------------------------- has_optimizers = is_overridden('configure_optimizers', model) - if not has_optimizers and not self.trainer.is_predicting: + if not has_optimizers and not self.trainer._predicting: raise MisconfigurationException( 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.' diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 7391bf4619640..8b0b42fb8d2c6 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -164,7 +164,7 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): model_ref = self.trainer.get_model() model_ref._results = Result() - if self.trainer.is_predicting: + if self.trainer._predicting: model_ref._current_fx_name = "predict" predictions = self.trainer.accelerator_backend.predict([args[0]]) self._predictions[dataloader_idx].append(predictions) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 928800590f0e0..0051ce314f85f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -296,7 +296,7 @@ def __init__( self._device_type = DeviceType.CPU self._distrib_type = None self._running_stage = None - self.is_predicting = False + self._predicting = False # init connectors self.dev_debugger = InternalDebugger(self) @@ -519,7 +519,7 @@ def _set_running_stage(self, stage): # trainer _running_state should be RunningStage.TESTING # however, the model running_stage should be RunningStage.PREDICTING or None if model_ref is not None: - if self.is_predicting: + if self._predicting: model_ref.running_stage = RunningStage.PREDICTING else: model_ref.running_stage = stage @@ -643,7 +643,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # lightning module methods with self.profiler.profile("evaluation_step_and_end"): output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) - if self.is_predicting: + if self._predicting: continue output = self.evaluation_loop.evaluation_step_end(output) @@ -659,7 +659,7 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None): # store batch level output per dataloader self.evaluation_loop.outputs.append(dl_outputs) - if self.is_predicting: + if self._predicting: return self.evaluation_loop.on_predict_epoch_end() # lightning module method @@ -913,7 +913,7 @@ def predict( self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders) # set path variable - self.is_predicting = True + self._predicting = True os.environ['PL_TESTING_MODE'] = '1' self.model = model @@ -922,7 +922,7 @@ def predict( # unset path variable self.teardown('test') del os.environ['PL_TESTING_MODE'] - self.is_predicting = False + self._predicting = False self._set_running_stage(None) return results From 43442a086dec2e0c327cda6b670b386c047a9526 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 13:47:33 +0000 Subject: [PATCH 38/43] add --- pytorch_lightning/core/lightning.py | 4 ++-- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 613a50cbd83d9..06ced12cfa5d9 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -977,11 +977,11 @@ def test_epoch_end(self, outputs): self.log('final_metric', final_value) """ - def predict(self, x: Any): + def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None): """ Use this function with trainer.predict(...). Override if you need to add any processing logic. """ - return self(x) + return self(batch) def configure_optimizers( self, diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 8b0b42fb8d2c6..c168693aac5f0 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -166,7 +166,7 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): if self.trainer._predicting: model_ref._current_fx_name = "predict" - predictions = self.trainer.accelerator_backend.predict([args[0]]) + predictions = self.trainer.accelerator_backend.predict(args) self._predictions[dataloader_idx].append(predictions) self.trainer._progress_bar_callback.on_test_batch_end( self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx) From e7fd0d6be3ef707d5088838a607e2485791764a1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 16:53:04 +0000 Subject: [PATCH 39/43] update --- tests/trainer/test_trainer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c9440bb2d4e02..c6883eb51bb7c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1452,18 +1452,23 @@ def test_trainer_profiler_incorrect_arg_type(profiler): Trainer(profiler=profiler) +class TestLightningDataModule(LightningDataModule): + + def __init__(self, dataloaders): + super().__init__() + self._dataloaders = dataloaders + + def test_dataloader(self): + return self._dataloaders + + def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True): dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)), torch.utils.data.DataLoader(RandomDataset(32, 2))] - class TestLightningDataModule(LightningDataModule): - - def test_dataloader(self): - return dataloaders - model = BoringModel() - datamodule = TestLightningDataModule() + datamodule = TestLightningDataModule(dataloaders) trainer = Trainer( default_root_dir=tmpdir, From 5a3a110de2a63a10f2acb925412290690e4bf23f Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 20:16:22 +0000 Subject: [PATCH 40/43] update --- tests/trainer/test_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4ff64143390eb..e21351704fd4c 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1489,6 +1489,8 @@ def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=T assert results[0][0].shape == torch.Size([1, 2]) +@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', + reason="test should be run outside of pytest") @pytest.mark.parametrize('datamodule', [False, True]) def test_trainer_predict_cpu(tmpdir, datamodule): predict(tmpdir, None, None, 1, datamodule=datamodule) From bd5c4c57c826133ff62a2e5f8cdd4ec1db4ecd07 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 26 Jan 2021 20:56:00 +0000 Subject: [PATCH 41/43] update doc --- docs/source/starter/introduction_guide.rst | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst index 870e9db4a9ef0..0456ea35f9567 100644 --- a/docs/source/starter/introduction_guide.rst +++ b/docs/source/starter/introduction_guide.rst @@ -881,7 +881,27 @@ Or maybe we have a model that we use to do generation z = sample_noise() generated_imgs = model(z) -How you split up what goes in ``forward`` vs ``training_step`` depends on how you want to use this model for + +To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict`` function +By default, LightningModule ``predict`` calls forward, but it can be overriden to add any processing logic. + +.. code-block:: python + + class LitMNISTDreamer(LightningModule): + + def forward(self, z): + imgs = self.decoder(z) + return imgs + + def predict(self, batch, batch_idx: int , dataloader_idx: int = None): + return self(batch) + + + model = LitMNISTDreamer() + trainer.predict(model, datamodule) + + +How you split up what goes in ``forward`` vs ``training_step`` vs ``predict`` depends on how you want to use this model for prediction. ---------------- From a29071957426a1543e7f8f841ee23993f8d319bc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 27 Jan 2021 09:43:25 +0100 Subject: [PATCH 42/43] chlog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 539c08e69ffe0..5c60c241516e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Add support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590)) + - Add Support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959)) @@ -68,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464)) +- Added Trainer method `predict(...)` for high performence predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579)) + + ### Changed - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) From 30da4da05103613f8856f2aadde107482787f5ac Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 27 Jan 2021 10:13:10 +0100 Subject: [PATCH 43/43] Apply suggestions from code review --- pytorch_lightning/trainer/evaluation_loop.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c168693aac5f0..2aa6f86dc0b9a 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -169,7 +169,8 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): predictions = self.trainer.accelerator_backend.predict(args) self._predictions[dataloader_idx].append(predictions) self.trainer._progress_bar_callback.on_test_batch_end( - self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx) + self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx + ) return elif self.testing: @@ -288,8 +289,7 @@ def __auto_reduce_result_objs(self, outputs): return eval_results def on_predict_epoch_end(self): - self.trainer._progress_bar_callback.on_test_end( - self.trainer, self.trainer.get_model()) + self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.get_model()) results = self._predictions diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7ad95242bb12f..28f2c9de333de 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -901,9 +901,7 @@ def predict( ) if model is None: - raise MisconfigurationException( - 'You need to pass a model to `trainer.predict`. ' - ) + raise MisconfigurationException('You need to pass a model to `trainer.predict`. ') if datamodule is not None: # Attach datamodule to get setup/prepare_data added to model before the call to it below