diff --git a/CHANGELOG.md b/CHANGELOG.md
index 024d92d22f782..0e465056457c3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Add support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
 
+- Add support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
 
@@ -68,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464))
 
+- Added Trainer method `predict(...)` for high performance predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579))
+
+
 - Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))
 
@@ -120,7 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated `TrainResult` ([#5323](https://github.com/PyTorchLightning/pytorch-lightning/pull/5323))
 
-
+
 - Removed deprecated `EvalResult` ([#5633](https://github.com/PyTorchLightning/pytorch-lightning/pull/5633))
 
@@ -155,7 +159,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))
 - Logging only on `not should_accumulate()` during training ([#5417](https://github.com/PyTorchLightning/pytorch-lightning/pull/5417))
-- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406))
+- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406))
 - Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
 
diff --git a/docs/source/starter/introduction_guide.rst b/docs/source/starter/introduction_guide.rst
index 870e9db4a9ef0..a7e7727f3152a 100644
--- a/docs/source/starter/introduction_guide.rst
+++ b/docs/source/starter/introduction_guide.rst
@@ -881,8 +881,30 @@ Or maybe we have a model that we use to do generation
 
     z = sample_noise()
     generated_imgs = model(z)
 
-How you split up what goes in ``forward`` vs ``training_step`` depends on how you want to use this model for
+
+To perform inference at scale, it is possible to use ``trainer.predict`` with the LightningModule ``predict`` function.
+By default, the LightningModule ``predict`` method calls ``forward``, but it can be overridden to add any processing logic.
+
+.. code-block:: python
+
+    class LitMNISTDreamer(LightningModule):
+
+        def forward(self, z):
+            imgs = self.decoder(z)
+            return imgs
+
+        def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
+            return self(batch)
+
+
+    model = LitMNISTDreamer()
+    trainer.predict(model, datamodule)
+
+
+How you split up what goes in ``forward`` vs ``training_step`` vs ``predict`` depends on how you want to use this model for
 prediction.
+However, we recommend that ``forward`` contain only the tensor operations of your model, ``training_step`` encapsulate the ``forward`` logic with logging,
+metrics and loss computation, and ``predict`` encapsulate ``forward`` with preprocessing and postprocessing functions.
 
 ----------------
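The docs change above recommends a three-way split between ``forward``, ``training_step`` and ``predict``. A minimal sketch of what that split looks like in practice — the ``LitClassifier`` module, its linear layer, and the concrete pre/postprocessing steps are illustrative assumptions, not code from this diff:

```python
import torch
from torch.nn import functional as F
from pytorch_lightning import LightningModule


class LitClassifier(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)

    def forward(self, x):
        # tensor operations of the model only
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # forward logic wrapped with loss computation and logging
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def predict(self, batch, batch_idx: int, dataloader_idx: int = None):
        # forward wrapped with (illustrative) pre- and postprocessing
        x = batch.float()                # preprocess
        return self(x).softmax(dim=-1)   # forward + postprocess
```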
diff --git a/pytorch_lightning/accelerators/legacy/cpu_accelerator.py b/pytorch_lightning/accelerators/legacy/cpu_accelerator.py
index f34162c602a55..b1ad39eaad042 100644
--- a/pytorch_lightning/accelerators/legacy/cpu_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/cpu_accelerator.py
@@ -77,6 +77,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(self.trainer.model.test_step, args)
 
+    def predict(self, args):
+        return self._step(self.trainer.model.predict, args)
+
     def sync_tensor(self,
                     tensor: Union[torch.Tensor],
                     group: Optional[Any] = None,
diff --git a/pytorch_lightning/accelerators/legacy/ddp2_accelerator.py b/pytorch_lightning/accelerators/legacy/ddp2_accelerator.py
index e4712f1270c57..9fa7b7ee2825c 100644
--- a/pytorch_lightning/accelerators/legacy/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/ddp2_accelerator.py
@@ -66,6 +66,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(args)
 
+    def predict(self, args):
+        return self._step(args)
+
     def _step(self, args):
         args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
diff --git a/pytorch_lightning/accelerators/legacy/ddp_accelerator.py b/pytorch_lightning/accelerators/legacy/ddp_accelerator.py
index 0899114b147c1..a7abdee146c5a 100644
--- a/pytorch_lightning/accelerators/legacy/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/ddp_accelerator.py
@@ -164,6 +164,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(args)
 
+    def predict(self, args):
+        return self._step(args)
+
     def _step(self, args):
         args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
diff --git a/pytorch_lightning/accelerators/legacy/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/legacy/ddp_cpu_spawn_accelerator.py
index 4609ef88c55a4..08649b856bb57 100644
--- a/pytorch_lightning/accelerators/legacy/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/ddp_cpu_spawn_accelerator.py
@@ -178,6 +178,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(args)
 
+    def predict(self, args):
+        return self._step(args)
+
     def _step(self, args):
         args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
diff --git a/pytorch_lightning/accelerators/legacy/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/legacy/ddp_hpc_accelerator.py
index f61423583435f..c22eb3e7b5755 100644
--- a/pytorch_lightning/accelerators/legacy/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/ddp_hpc_accelerator.py
@@ -83,6 +83,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(args)
 
+    def predict(self, args):
+        return self._step(args)
+
     def _step(self, args):
         args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
diff --git a/pytorch_lightning/accelerators/legacy/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/legacy/ddp_spawn_accelerator.py
index c768db3dd16b4..350a78d59f6d3 100644
--- a/pytorch_lightning/accelerators/legacy/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/ddp_spawn_accelerator.py
@@ -212,6 +212,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(args)
 
+    def predict(self, args):
+        return self._step(args)
+
     def _step(self, args):
         args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
         if self.trainer.amp_backend == AMPType.NATIVE:
diff --git a/pytorch_lightning/accelerators/legacy/dp_accelerator.py b/pytorch_lightning/accelerators/legacy/dp_accelerator.py
index ec2cb54531e4c..cbacb82c80dc0 100644
--- a/pytorch_lightning/accelerators/legacy/dp_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/dp_accelerator.py
@@ -132,6 +132,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(args)
 
+    def predict(self, args):
+        return self._step(args)
+
     def training_step_end(self, output):
         if isinstance(output, Result):
             output.dp_reduce()
diff --git a/pytorch_lightning/accelerators/legacy/gpu_accelerator.py b/pytorch_lightning/accelerators/legacy/gpu_accelerator.py
index e1410a2946c7c..1d73919454801 100644
--- a/pytorch_lightning/accelerators/legacy/gpu_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/gpu_accelerator.py
@@ -85,6 +85,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(self.trainer.model.test_step, args)
 
+    def predict(self, args):
+        return self._step(self.trainer.model.predict, args)
+
     def to_device(self, batch):
         gpu_id = 0
         if isinstance(self.trainer.data_parallel_device_ids, list):
diff --git a/pytorch_lightning/accelerators/legacy/horovod_accelerator.py b/pytorch_lightning/accelerators/legacy/horovod_accelerator.py
index 4a15d765b817b..9c9160d7bc15d 100644
--- a/pytorch_lightning/accelerators/legacy/horovod_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/horovod_accelerator.py
@@ -134,6 +134,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(self.trainer.model.test_step, args)
 
+    def predict(self, args):
+        return self._step(self.trainer.model.predict, args)
+
     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
         super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
         optimizer.synchronize()
diff --git a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py b/pytorch_lightning/accelerators/legacy/tpu_accelerator.py
index 0f4014df04a8a..88b73fe94939f 100644
--- a/pytorch_lightning/accelerators/legacy/tpu_accelerator.py
+++ b/pytorch_lightning/accelerators/legacy/tpu_accelerator.py
@@ -159,6 +159,9 @@ def validation_step(self, args):
     def test_step(self, args):
         return self._step(self.trainer.model.test_step, args)
 
+    def predict(self, args):
+        return self._step(self.trainer.model.predict, args)
+
     def process_dataloader(self, dataloader):
         device = xm.xla_device(self.trainer.tpu_id)
         dataloader = xla_pl.ParallelLoader(dataloader, [device])
diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py
index cc1f6c6d99cb3..dace4cc33ca99 100644
--- a/pytorch_lightning/callbacks/progress.py
+++ b/pytorch_lightning/callbacks/progress.py
@@ -291,10 +291,11 @@ def init_validation_tqdm(self) -> tqdm:
         )
         return bar
 
-    def init_test_tqdm(self) -> tqdm:
+    def init_test_tqdm(self, trainer=None) -> tqdm:
         """
         Override this to customize the tqdm bar for testing.
         """
+        desc = "Predicting" if trainer is not None and getattr(trainer, "is_predicting", False) else "Testing"
         bar = tqdm(
-            desc='Testing',
+            desc=desc,
             position=(2 * self.process_position),
             disable=self.is_disabled,
             leave=True,
@@ -361,7 +363,7 @@ def on_train_end(self, trainer, pl_module):
 
     def on_test_start(self, trainer, pl_module):
         super().on_test_start(trainer, pl_module)
-        self.test_progress_bar = self.init_test_tqdm()
+        self.test_progress_bar = self.init_test_tqdm(trainer=trainer)
         self.test_progress_bar.total = convert_inf(self.total_test_batches)
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index b9b12d56712db..c453bd5d607d6 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -66,6 +66,7 @@ class LightningModule(
         "on_gpu",
         "current_epoch",
         "global_step",
+        "running_stage",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__
 
     def __init__(self, *args, **kwargs):
@@ -102,6 +103,7 @@ def __init__(self, *args, **kwargs):
         self._running_manual_backward = False
         self._current_hook_fx_name = None
         self._current_dataloader_idx = None
+        self.running_stage = None
         self._automatic_optimization: bool = True
 
     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -982,6 +984,12 @@ def test_epoch_end(self, outputs):
             self.log('final_metric', final_value)
         """
 
+    def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None):
+        """
+        Use this function with trainer.predict(...). Override it if you need to add any processing logic.
+        """
+        return self(batch)
+
     def configure_optimizers(
         self,
     ):
diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 69676cf77e079..8ba67aba8b3e4 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -28,6 +28,7 @@
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.warnings import WarningCache
 
 
@@ -78,14 +79,22 @@ def forward(self, *inputs, **kwargs):
                                        "them on device: {}".format(self.src_device_obj, t.device))
 
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+
         if len(self.device_ids) == 1:
-            # lightning
-            if self.module.training:
+
+            running_stage = self.module.running_stage
+
+            if running_stage == RunningStage.TRAINING:
                 return self.module.training_step(*inputs[0], **kwargs[0])
-            if self.module.testing:
+
+            elif running_stage == RunningStage.TESTING:
                 return self.module.test_step(*inputs[0], **kwargs[0])
 
-            return self.module.validation_step(*inputs[0], **kwargs[0])
+            elif running_stage == RunningStage.EVALUATING:
+                return self.module.validation_step(*inputs[0], **kwargs[0])
+
+            else:
+                return self.module.predict(*inputs[0], **kwargs[0])
 
         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
         outputs = self.parallel_apply(replicas, inputs, kwargs)
@@ -187,15 +196,24 @@ def __init__(self, pl_module: LightningModule):
         self.module = pl_module
 
     def forward(self, *inputs, **kwargs):
-        if self.module.training:
+
+        running_stage = self.module.running_stage
+
+        if running_stage == RunningStage.TRAINING:
             output = self.module.training_step(*inputs, **kwargs)
             warn_if_output_is_none(output, "training_step")
-        elif self.module.testing:
+
+        elif running_stage == RunningStage.TESTING:
             output = self.module.test_step(*inputs, **kwargs)
             warn_if_output_is_none(output, "test_step")
-        else:
+
+        elif running_stage == RunningStage.EVALUATING:
             output = self.module.validation_step(*inputs, **kwargs)
             warn_if_output_is_none(output, "validation_step")
+
+        else:
+            output = self.module.predict(*inputs, **kwargs)
+
         return output
 
@@ -276,16 +294,22 @@ def _worker(i, module, input, kwargs, device=None):
 
             # ---------------
             # CHANGE
-            if module.training:
+            if module.running_stage == RunningStage.TRAINING:
                 output = module.training_step(*input, **kwargs)
                 fx_called = 'training_step'
-            elif module.testing:
+
+            elif module.running_stage == RunningStage.TESTING:
                 output = module.test_step(*input, **kwargs)
                 fx_called = 'test_step'
-            else:
+
+            elif module.running_stage == RunningStage.EVALUATING:
                 output = module.validation_step(*input, **kwargs)
                 fx_called = 'validation_step'
 
+            else:
+                output = module.predict(*input, **kwargs)
+                fx_called = 'predict'
+
             if output is None:
                 warn_missing_output(fx_called)
diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py
index 8297d708b26c2..f413065f627ff 100644
--- a/pytorch_lightning/overrides/fairscale.py
+++ b/pytorch_lightning/overrides/fairscale.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
 
 LightningShardedDataParallel = None
@@ -23,10 +24,18 @@ def forward(self, *inputs, **kwargs):
             if self.enable_broadcast_buffers:
                 self.sync_buffers()
 
-            if self.module.training:
+            running_stage = self.module.running_stage
+
+            if running_stage == RunningStage.TRAINING:
                 outputs = self.module.training_step(*inputs, **kwargs)
-            elif self.module.testing:
+
+            elif running_stage == RunningStage.TESTING:
                 outputs = self.module.test_step(*inputs, **kwargs)
-            else:
+
+            elif running_stage == RunningStage.EVALUATING:
                 outputs = self.module.validation_step(*inputs, **kwargs)
+
+            else:
+                outputs = self.module.predict(*inputs, **kwargs)
+
             return outputs
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 12aa27279aee4..584bdd7772429 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -52,7 +52,7 @@ def __verify_train_loop_configuration(self, model):
         # verify model has a train dataloader
         # -----------------------------------
         has_train_dataloader = is_overridden('train_dataloader', model)
-        if not has_train_dataloader:
+        if not has_train_dataloader and not self.trainer._predicting:
             raise MisconfigurationException(
                 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
                 ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
@@ -62,7 +62,7 @@ def __verify_train_loop_configuration(self, model):
         # verify model has optimizer
         # -----------------------------------
         has_optimizers = is_overridden('configure_optimizers', model)
-        if not has_optimizers:
+        if not has_optimizers and not self.trainer._predicting:
             raise MisconfigurationException(
                 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
                 ' `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.'
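Every wrapper change above (DP, DDP, sharded) applies the same routing: branch on the module's ``running_stage`` instead of the old ``training``/``testing`` booleans, with ``predict`` as the fall-through branch. A condensed, self-contained sketch of that dispatch — the standalone ``dispatch`` helper and the local enum are illustrative stand-ins, not APIs added by this diff:

```python
from enum import Enum


class RunningStage(str, Enum):
    # mirrors pytorch_lightning.trainer.states.RunningStage after this PR
    TRAINING = 'train'
    EVALUATING = 'eval'
    TESTING = 'test'
    PREDICTING = 'predict'
    TUNING = 'tune'


def dispatch(module, *inputs, **kwargs):
    # the branch order used by the DP/DDP/fairscale wrappers above;
    # anything that is not train/test/eval falls through to predict
    stage = module.running_stage
    if stage == RunningStage.TRAINING:
        return module.training_step(*inputs, **kwargs)
    elif stage == RunningStage.TESTING:
        return module.test_step(*inputs, **kwargs)
    elif stage == RunningStage.EVALUATING:
        return module.validation_step(*inputs, **kwargs)
    return module.predict(*inputs, **kwargs)
```

Making ``predict`` the ``else`` branch (rather than an explicit ``PREDICTING`` check) matches the diff: a module whose ``running_stage`` is ``None`` or ``PREDICTING`` ends up in ``predict``.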
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index 2e6053d74b55f..62e243e3e3ed5 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -17,12 +17,12 @@
 import torch
 
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import DistributedType, LightningEnum
 
 
 class LoggerStages(LightningEnum):
     """ Train/validation/test phase in each training step.
-    >>> # you can math the type with string
     >>> LoggerStages.TRAIN == 'train'
     True
@@ -371,7 +371,7 @@ def update_logger_connector(self) -> None:
         callback_metrics = {}
         batch_pbar_metrics = {}
         batch_log_metrics = {}
-        is_train = self._stage in LoggerStages.TRAIN.value
+        is_train = self._stage in RunningStage.TRAINING
 
         if not self._has_batch_loop_finished:
             # get pbar
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 47355c8d097ad..f6700187c3912 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -22,8 +22,9 @@
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
-from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
+from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
 from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import DeviceType, flatten_dict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -37,9 +38,9 @@ def __init__(self, trainer):
         self._logged_metrics = MetricsHolder()
         self._progress_bar_metrics = MetricsHolder()
         self.eval_loop_results = []
-        self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages}
+        self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in RunningStage}
+        self._cached_results[None] = EpochResultStore(trainer, None)
         self._callback_hook_validator = CallbackHookNameValidator()
-        self._current_stage = None
 
     @property
     def callback_metrics(self) -> Dict:
@@ -75,7 +76,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:
 
     @property
     def cached_results(self) -> Union[EpochResultStore, None]:
-        return self._cached_results.get(self._current_stage)  # type: ignore
+        return self._cached_results.get(self.trainer._running_stage)  # type: ignore
 
     def get_metrics(self, key: str) -> Dict:
         metrics_holder = getattr(self, f"_{key}", None)
@@ -90,10 +91,8 @@ def set_metrics(self, key: str, val: Any) -> None:
         metrics_holder = getattr(self, f"_{key}", None)
         metrics_holder.reset(val)
 
-    def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
-        self._current_stage = LoggerStages.determine_stage(stage_or_testing)
-        if reset:
-            self.cached_results.reset()
+    def reset(self) -> None:
+        self.cached_results.reset()
 
     def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None:
         self._callback_hook_validator.check_logging_in_callbacks(
@@ -119,8 +118,7 @@ def on_train_batch_end(self) -> None:
         self.cached_results._batch_size = None
 
     def cache_logged_metrics(self):
-        if self._current_stage is not None:
-            self._cached_results[self._current_stage].cache_result()
+        self._cached_results[self.trainer._running_stage].cache_result()
 
     def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
         # logging
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index f2573e25a01e8..2aa6f86dc0b9a 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -15,6 +15,7 @@
 
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -128,6 +129,7 @@ def setup(self, model, max_batches, dataloaders):
 
         self.max_batches = max_batches
         self.num_dataloaders = self._get_num_dataloaders(dataloaders)
+        self._predictions = [[] for _ in range(self.num_dataloaders)]
 
     def on_evaluation_epoch_start(self, *args, **kwargs):
         if self.testing:
@@ -161,8 +163,17 @@ def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
         model_ref = self.trainer.get_model()
         model_ref._results = Result()
 
-        # run actual test step
-        if self.testing:
+
+        if self.trainer._predicting:
+            model_ref._current_fx_name = "predict"
+            predictions = self.trainer.accelerator_backend.predict(args)
+            self._predictions[dataloader_idx].append(predictions)
+            self.trainer._progress_bar_callback.on_test_batch_end(
+                self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
+            )
+            return
+
+        elif self.testing:
             model_ref._current_fx_name = "test_step"
             with self.trainer.profiler.profile("test_step"):
                 output = self.trainer.accelerator_backend.test_step(args)
@@ -277,6 +288,18 @@ def __auto_reduce_result_objs(self, outputs):
 
         return eval_results
 
+    def on_predict_epoch_end(self):
+        self.trainer._progress_bar_callback.on_test_end(self.trainer, self.trainer.get_model())
+
+        results = self._predictions
+
+        def _convert_to_numpy(v):
+            return v.cpu().numpy()
+
+        results = apply_to_collection(results, torch.Tensor, _convert_to_numpy)
+
+        return results, None
+
     def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
         # set dataloader_idx to model and track batch_size
         self.trainer.logger_connector.on_evaluation_batch_start(
diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py
index 6909a80251850..a3ef08df1e49e 100644
--- a/pytorch_lightning/trainer/states.py
+++ b/pytorch_lightning/trainer/states.py
@@ -46,6 +46,7 @@ class RunningStage(LightningEnum):
     TRAINING = 'train'
     EVALUATING = 'eval'
     TESTING = 'test'
+    PREDICTING = 'predict'
     TUNING = 'tune'
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b97377a150e53..ba34c49581038 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -297,6 +297,7 @@ def __init__(
         self._device_type = DeviceType.CPU
         self._distrib_type = None
         self._running_stage = None
+        self._predicting = False
 
         # init connectors
         self.dev_debugger = InternalDebugger(self)
@@ -444,6 +445,7 @@ def fit(
         """
         # bookkeeping
         self._state = TrainerState.RUNNING
+        self._set_wide_running_stage(RunningStage.TRAINING)
 
         # ----------------------------
         # LINK DATA
@@ -478,7 +480,6 @@ def fit(
 
         # ----------------------------
         # hook
         self.call_hook('on_fit_start')
-
         results = self.accelerator_backend.train()
         self.accelerator_backend.teardown()
 
@@ -498,13 +499,39 @@ def fit(
         if self._state != TrainerState.INTERRUPTED:
             self._state = TrainerState.FINISHED
+
+        self._set_wide_running_stage(None)
+
         return results or 1
 
+    def _set_wide_running_stage(self, stage):
+        model_ref = self.get_model()
+
+        if stage is None:
+            self._running_stage = stage
+            model_ref.running_stage = stage
+            return
+
+        # todo: clean up this routing mess.
+        if self._running_stage == RunningStage.TESTING:
+            stage = RunningStage.TESTING
+
+        # WARNING: when predicting, the trainer `_running_stage` should be RunningStage.TESTING,
+        # however, the model `running_stage` should be RunningStage.PREDICTING or None
+        if model_ref is not None:
+            if self._predicting:
+                model_ref.running_stage = RunningStage.PREDICTING
+            else:
+                model_ref.running_stage = stage
+
+        self._running_stage = stage
+
     def train(self):
         self.run_sanity_check(self.get_model())
 
         # set stage for logging
-        self.logger_connector.set_stage("train")
+        self._set_wide_running_stage(RunningStage.TRAINING)
 
         self.checkpoint_connector.has_trained = False
@@ -566,7 +593,8 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
 
         # used to know if we are logging for val, test + reset cached results
-        self.logger_connector.set_stage(test_mode, reset=True)
+        self._set_wide_running_stage(RunningStage.TESTING if test_mode else RunningStage.EVALUATING)
+        self.logger_connector.reset()
 
         # bookkeeping
         self.evaluation_loop.testing = test_mode
@@ -616,6 +644,8 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
                 # lightning module methods
                 with self.profiler.profile("evaluation_step_and_end"):
                     output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
+                    if self._predicting:
+                        continue
                     output = self.evaluation_loop.evaluation_step_end(output)
 
                 # hook + store predictions
@@ -630,6 +660,9 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
             # store batch level output per dataloader
             self.evaluation_loop.outputs.append(dl_outputs)
 
+        if self._predicting:
+            return self.evaluation_loop.on_predict_epoch_end()
+
         # lightning module method
         deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()
 
@@ -739,14 +772,14 @@ def test(
             verbose: If True, prints the test results
 
         Returns:
-            The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries
+            Returns a list of dictionaries, one for each test dataloader, containing their respective metrics.
         """
         # --------------------
         # SETUP HOOK
         # --------------------
         self.verbose_test = verbose
 
-        self.logger_connector.set_stage("test")
+        self._set_wide_running_stage(RunningStage.TESTING)
 
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:
@@ -764,6 +797,8 @@ def test(
 
         self.teardown('test')
 
+        self._set_wide_running_stage(None)
+
         return results
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
@@ -832,6 +867,65 @@ def __test_given_model(self, model, test_dataloaders):
 
         return results
 
+    def predict(
+        self,
+        model: Optional[LightningModule] = None,
+        dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
+        datamodule: Optional[LightningDataModule] = None,
+    ):
+        r"""
+
+        This is separate from ``fit`` to make sure you never run on your predictions set until you want to.
+
+        This will call the model forward function to compute predictions.
+
+        Args:
+            model: The model to predict on.
+
+            dataloaders: Either a single
+                PyTorch DataLoader or a list of them, specifying inference samples.
+
+            datamodule: An instance of :class:`LightningDataModule`.
+
+        Returns:
+            Returns a list of dictionaries, one for each provided dataloader, containing their respective predictions.
+        """
+
+        # --------------------
+        # SETUP HOOK
+        # --------------------
+        # If you supply a datamodule you can't supply dataloaders
+        if dataloaders and datamodule:
+            raise MisconfigurationException(
+                'You cannot pass dataloaders to trainer.predict if you supply a datamodule'
+            )
+
+        if model is None:
+            raise MisconfigurationException('You need to pass a model to `trainer.predict`.')
+
+        if datamodule is not None:
+            # Attach datamodule to get setup/prepare_data added to model before the call to it below
+            self.data_connector.attach_datamodule(model, datamodule, 'test')
+
+        # attach data
+        if dataloaders is not None:
+            self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)
+
+        # set predicting flags
+        self._predicting = True
+        os.environ['PL_TESTING_MODE'] = '1'
+        self.model = model
+
+        results = self.fit(model)
+
+        # unset predicting flags
+        self.teardown('test')
+        del os.environ['PL_TESTING_MODE']
+        self._predicting = False
+        self._set_wide_running_stage(None)
+
+        return results
+
     def tune(
         self,
         model: LightningModule,
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e4ae2a717e8d5..85c5758ec27be 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -23,7 +23,7 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import RunningStage, TrainerState
 from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
 from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
 from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
@@ -600,8 +600,9 @@ def run_training_epoch(self):
             should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
             if should_check_val:
                 self.trainer.run_evaluation(test_mode=False)
+
                 # reset stage to train
-                self.trainer.logger_connector.set_stage("train")
+                self.trainer._set_wide_running_stage(RunningStage.TRAINING)
 
             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 0320a3dbd0c82..75eb8abc79c04 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -283,8 +283,8 @@ def init_validation_tqdm(self):
             bar = super().init_validation_tqdm()
             return self._mock_bar_update(bar)
 
-        def init_test_tqdm(self):
-            bar = super().init_test_tqdm()
+        def init_test_tqdm(self, trainer=None):
+            bar = super().init_test_tqdm(trainer=trainer)
             return self._mock_bar_update(bar)
 
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 26df7fd348cc8..f34c0e196bf85 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -25,7 +25,7 @@
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import RunningStage, TrainerState
 from tests.base import BoringModel, EvalModelTemplate, GenericEvalModelTemplate
 
@@ -393,6 +393,7 @@ def assert_good_acc():
         # haven't trained with the new loaded model
         dp_model = new_trainer.model
         dp_model.eval()
+        dp_model.module.running_stage = RunningStage.EVALUATING
 
         dataloader = trainer.train_dataloader
         tpipes.run_prediction(dp_model, dataloader, dp=True)
diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py
index 8c8f1649e73c7..e61b81fd8488e 100644
--- a/tests/overrides/test_data_parallel.py
+++ b/tests/overrides/test_data_parallel.py
@@ -4,6 +4,7 @@
 import torch
 
 from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
+from pytorch_lightning.trainer.states import RunningStage
 
 
 def test_lightning_distributed_module_methods():
@@ -14,18 +15,15 @@ def test_lightning_distributed_module_methods():
     batch = torch.rand(5)
     batch_idx = 3
 
-    pl_module.training = True
-    pl_module.testing = False
+    pl_module.running_stage = RunningStage.TRAINING
     dist_module(batch, batch_idx)
     pl_module.training_step.assert_called_with(batch, batch_idx)
 
-    pl_module.training = False
-    pl_module.testing = True
+    pl_module.running_stage = RunningStage.TESTING
     dist_module(batch, batch_idx)
     pl_module.test_step.assert_called_with(batch, batch_idx)
 
-    pl_module.training = False
-    pl_module.testing = False
+    pl_module.running_stage = RunningStage.EVALUATING
     dist_module(batch, batch_idx)
     pl_module.validation_step.assert_called_with(batch, batch_idx)
 
@@ -40,16 +38,13 @@ def test_lightning_distributed_module_warn_none_output():
     pl_module.test_step.return_value = None
 
     with pytest.warns(UserWarning, match="Your training_step returned None"):
-        pl_module.training = True
-        pl_module.testing = False
+        pl_module.running_stage = RunningStage.TRAINING
         dist_module()
 
     with pytest.warns(UserWarning, match="Your test_step returned None"):
-        pl_module.training = False
-        pl_module.testing = True
+        pl_module.running_stage = RunningStage.TESTING
         dist_module()
 
     with pytest.warns(UserWarning, match="Your validation_step returned None"):
-        pl_module.training = False
-        pl_module.testing = False
+        pl_module.running_stage = RunningStage.EVALUATING
         dist_module()
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index 0fff6c00a01c0..ce2a058974df7 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -22,6 +22,8 @@ python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_
 python ${DEFAULTS} tests/plugins/legacy/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
 python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
 # python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_ddp
+python ${DEFAULTS} tests/trainer/test_trainer.py::test_trainer_predict_dp
 python ${DEFAULTS} tests/trainer/logging_/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
 python ${DEFAULTS} tests/callbacks/test_pruning.py::test_pruning_callback_ddp
 python ${DEFAULTS} tests/trainer/test_trainer.py::test_pytorch_profiler_trainer_ddp
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index ac124b71db3a4..e21351704fd4c 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -14,6 +14,7 @@
 import math
 import os
 import pickle
+import platform
 import sys
 from argparse import Namespace
 from copy import deepcopy
@@ -27,7 +28,7 @@
 from omegaconf import OmegaConf
 
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -37,7 +38,7 @@
 from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel, EvalModelTemplate
+from tests.base import BoringModel, EvalModelTemplate, RandomDataset
 
 
 @pytest.fixture
@@ -1447,6 +1448,86 @@ def test_trainer_profiler_incorrect_arg_type(profiler):
         Trainer(profiler=profiler)
 
 
+class TestLightningDataModule(LightningDataModule):
+
+    def __init__(self, dataloaders):
+        super().__init__()
+        self._dataloaders = dataloaders
+
+    def test_dataloader(self):
+        return self._dataloaders
+
+
+def predict(tmpdir, accelerator, gpus, num_processes, plugins=None, datamodule=True):
+
+    dataloaders = [torch.utils.data.DataLoader(RandomDataset(32, 2)),
+                   torch.utils.data.DataLoader(RandomDataset(32, 2))]
+
+    model = BoringModel()
+    dm = TestLightningDataModule(dataloaders)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+        accelerator=accelerator,
+        gpus=gpus,
+        num_processes=num_processes,
+        plugins=plugins,
+        num_sanity_val_steps=0
+    )
+    if datamodule:
+        results = trainer.predict(model, datamodule=dm)
+    else:
+        results = trainer.predict(model, dataloaders=dataloaders)
+
+    # todo: address this in another PR
+    num_samples = 1 if accelerator in ["ddp", "ddp_cpu", "ddp_spawn"] else 2
+    assert len(results) == 2
+    assert len(results[0]) == num_samples
+    assert results[0][0].shape == torch.Size([1, 2])
+
+
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@pytest.mark.parametrize('datamodule', [False, True])
+def test_trainer_predict_cpu(tmpdir, datamodule):
+    predict(tmpdir, None, None, 1, datamodule=datamodule)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@pytest.mark.parametrize('num_gpus', [1, 2])
+def test_trainer_predict_dp(tmpdir, num_gpus):
+    predict(tmpdir, "dp", num_gpus, None)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@pytest.mark.parametrize('plugins', [None, "ddp_sharded"])
+def test_trainer_predict_ddp(tmpdir, plugins):
+    predict(tmpdir, "ddp", 2, None, plugins=plugins)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
+def test_trainer_predict_ddp_spawn(tmpdir):
+    predict(tmpdir, "ddp_spawn", 2, None)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
+def test_trainer_predict_1_gpu(tmpdir):
+    predict(tmpdir, None, 1, None)
+
+
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
+def test_trainer_predict_ddp_cpu(tmpdir):
+    predict(tmpdir, "ddp_cpu", 0, 2)
+
+
 def test_pytorch_profiler_describe(pytorch_profiler):
     """Ensure the profiler won't fail when reporting the summary."""
     with pytorch_profiler.profile("test_step"):