From cd9b21c014b4ef78af7a5da78d152c35554acd15 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Jun 2021 03:10:11 +0200 Subject: [PATCH 01/17] Deprecate passing extras with graphs --- .../connectors/logger_connector/result.py | 14 ++++++++++++-- tests/deprecated_api/test_remove_1-6.py | 15 +++++++++++++++ .../trainer/logging_/test_train_loop_logging.py | 16 ++++++++-------- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index cbc3dcfdefd98..36e1535a0e683 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -26,11 +26,14 @@ from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.metrics import metrics_to_scalars +from pytorch_lightning.utilities.warnings import WarningCache # re-define the ones from pytorch_lightning.utilities.types without the `Number` type _METRIC = Union[Metric, torch.Tensor] _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]] +warning_cache = WarningCache() + class MetricSource(LightningEnum): CALLBACK = "callback" @@ -279,9 +282,16 @@ def extra(self, extra: Mapping[str, Any]) -> None: def check_fn(v): if v.grad_fn is not None: - raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') + # raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') + warning_cache.warn( + f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically" + " but this behaviour will change in v1.6. Please detach it manually:" + " `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning + ) + return v.detach() + return v - apply_to_collection(extra, torch.Tensor, check_fn) + extra = apply_to_collection(extra, torch.Tensor, check_fn) self['_extra'] = extra def log( diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index cb150cb013ec2..2fa54f3b253fb 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -212,3 +212,18 @@ def test_v1_6_0_early_stopping_monitor(tmpdir): " For backward compatibility, setting this to `early_stop_on`." 
): EarlyStopping() + + +def test_v1_6_0_extras_with_gradients(tmpdir): + + class TestModel(BoringModel): + + def training_step(self, *args): + loss = super().training_step(*args)['loss'] + return {"loss": loss, 'foo': loss} + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) + model = TestModel() + match = r"\{'foo'\} has a `grad_fn`.*behaviour will change in v1\.6" + with pytest.deprecated_call(match=match): + trainer.fit(model) diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index bff558e81b29e..861cded01b8cf 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -690,16 +690,16 @@ def training_step(self, batch, batch_idx): with pytest.raises(MisconfigurationException, match='`self.log` with the key `foo/dataloader_idx_0`'): trainer.fit(model) - class TestModel(BoringModel): + # class TestModel(BoringModel): - def training_step(self, *args): - loss = super().training_step(*args)['loss'] - return {"loss": loss, 'foo': loss} + # def training_step(self, *args): + # loss = super().training_step(*args)['loss'] + # return {"loss": loss, 'foo': loss} - trainer = Trainer(default_root_dir=tmpdir) - model = TestModel() - with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'): - trainer.fit(model) + # trainer = Trainer(default_root_dir=tmpdir) + # model = TestModel() + # with pytest.raises(MisconfigurationException, match='You returned a tensor with `grad_fn`'): + # trainer.fit(model) class TestModel(BoringModel): From 8092946af6ee63aa1b477205cb6bfed12065e323 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Jun 2021 03:15:04 +0200 Subject: [PATCH 02/17] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ddaf4288a0202..5575f7aeaf3f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `is_overridden(model=...)` in favor of `is_overridden(instance=...)` ([#7918](https://github.com/PyTorchLightning/pytorch-lightning/pull/7918)) +- Deprecated automatically detaching returned extras with grads ([#7994](https://github.com/PyTorchLightning/pytorch-lightning/pull/7994)) + + - Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907)) From 58142139bcf3c74056fdae24695e4bc2788d98e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 16 Jun 2021 12:59:44 +0200 Subject: [PATCH 03/17] remove DeprecationWarning class to make the warning show up in sterr. --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 36e1535a0e683..01b5be8bf3359 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -286,7 +286,7 @@ def check_fn(v): warning_cache.warn( f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically" " but this behaviour will change in v1.6. 
Please detach it manually:" - " `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning + " `return {'loss': ..., 'something': something.detach()}`", ) return v.detach() return v From 47bd54c0c2cc19657d682871474aec4c6987b2c8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Jun 2021 14:02:54 +0200 Subject: [PATCH 04/17] Revert "remove DeprecationWarning class to make the warning show up in sterr." This reverts commit 58142139bcf3c74056fdae24695e4bc2788d98e4. --- pytorch_lightning/trainer/connectors/logger_connector/result.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 01b5be8bf3359..36e1535a0e683 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -286,7 +286,7 @@ def check_fn(v): warning_cache.warn( f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically" " but this behaviour will change in v1.6. Please detach it manually:" - " `return {'loss': ..., 'something': something.detach()}`", + " `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning ) return v.detach() return v From 3dc65394006d3c6420ec2e40cf526e3ed58e827e Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Jun 2021 14:51:40 +0200 Subject: [PATCH 05/17] Fix deprecation warnings --- .../callbacks/model_checkpoint.py | 4 +- pytorch_lightning/core/datamodule.py | 2 +- pytorch_lightning/core/grads.py | 2 +- pytorch_lightning/core/lightning.py | 7 ++-- pytorch_lightning/loggers/csv_logs.py | 3 +- pytorch_lightning/loggers/test_tube.py | 4 +- pytorch_lightning/loggers/wandb.py | 4 +- .../loops/training_batch_loop.py | 8 ++-- .../loops/training_epoch_loop.py | 4 +- .../plugins/training_type/ddp_spawn.py | 13 +++--- .../plugins/training_type/deepspeed.py | 7 ++-- pytorch_lightning/profiler/base.py | 5 +-- pytorch_lightning/profiler/pytorch.py | 6 +-- pytorch_lightning/trainer/callback_hook.py | 4 +- .../connectors/accelerator_connector.py | 4 +- .../connectors/logger_connector/result.py | 4 +- .../connectors/training_trick_connector.py | 3 +- pytorch_lightning/trainer/logging.py | 2 +- pytorch_lightning/trainer/model_hooks.py | 2 +- pytorch_lightning/trainer/training_loop.py | 8 ++-- pytorch_lightning/utilities/__init__.py | 9 +--- pytorch_lightning/utilities/device_parser.py | 5 +-- pytorch_lightning/utilities/distributed.py | 24 +++++------ pytorch_lightning/utilities/warnings.py | 41 ++++++++++++++----- 24 files changed, 96 insertions(+), 79 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 067ebfdeafbe7..0d1132f191652 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -650,10 +650,10 @@ def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None: self.save_top_k = 1 if deprecation_warning: - warning_cache.warn( + warning_cache.deprecation( "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2" " and will be removed in v1.4. 
Please, create your own `mc = ModelCheckpoint(monitor='your_monitor')`" - " and use it as `Trainer(callbacks=[mc])`.", DeprecationWarning + " and use it as `Trainer(callbacks=[mc])`.", ) def _validate_monitor_key(self, trainer: 'pl.Trainer') -> None: diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 9dd8066f15080..df3fa26a24a17 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -20,8 +20,8 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -from pytorch_lightning.utilities.distributed import rank_zero_deprecation class LightningDataModule(CheckpointHooks, DataHooks): diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py index 30a2f0ae7e38f..f6a0d41035460 100644 --- a/pytorch_lightning/core/grads.py +++ b/pytorch_lightning/core/grads.py @@ -18,7 +18,7 @@ from torch.nn import Module -from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.grads import grad_norm as new_grad_norm diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bc070b25e7b4e..663b0e470664b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -170,7 +170,8 @@ def example_input_array(self, example: Any) -> None: def datamodule(self) -> Any: rank_zero_deprecation( "The `LightningModule.datamodule` property is deprecated in v1.3 and will be removed in v1.5." - " Access the datamodule through using `self.trainer.datamodule` instead." + " Access the datamodule through using `self.trainer.datamodule` instead.", + stacklevel=5, ) return self._datamodule @@ -223,10 +224,10 @@ def _apply_batch_transfer_handler( if is_param_in_hook_signature(self.transfer_batch_to_device, 'dataloader_idx'): batch = self.transfer_batch_to_device(batch, device, dataloader_idx) else: - warning_cache.warn( + warning_cache.deprecation( "`transfer_batch_to_device` hook signature has changed in v1.4." " `dataloader_idx` parameter has been added to it. 
Support for" - " the old signature will be removed in v1.6", DeprecationWarning + " the old signature will be removed in v1.6" ) batch = self.transfer_batch_to_device(batch, device) diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index 4df672fa6e3b5..754a7cf892060 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -29,7 +29,8 @@ from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_only log = logging.getLogger(__name__) diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 84f231b0f16d7..1107a0bcb2c4c 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -20,8 +20,8 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment -from pytorch_lightning.utilities import _module_available -from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import _module_available, rank_zero_warn +from pytorch_lightning.utilities.distributed import rank_zero_only _TESTTUBE_AVAILABLE = _module_available("test_tube") diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index c127fa037ed6b..f26e5b7783c9b 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -137,9 +137,9 @@ def __init__( ) if sync_step is not None: - warning_cache.warn( + warning_cache.deprecation( "`WandbLogger(sync_step=(True|False))` is deprecated in v1.2.1 and will be removed in v1.5." - " Metrics are now logged separately and automatically synchronized.", DeprecationWarning + " Metrics are now logged separately and automatically synchronized.", ) super().__init__() diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py index b581c6c8c1384..39806e46fcbd6 100644 --- a/pytorch_lightning/loops/training_batch_loop.py +++ b/pytorch_lightning/loops/training_batch_loop.py @@ -488,10 +488,10 @@ def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Te if len(self.trainer.optimizers) > 1: if self.trainer.has_arg("training_step", "optimizer_idx"): if not self.trainer.lightning_module.automatic_optimization: - self.warning_cache.warn( + self.warning_cache.deprecation( "`training_step` hook signature has changed in v1.3." " `optimizer_idx` argument has been removed in case of manual optimization. Support for" - " the old signature will be removed in v1.5", DeprecationWarning + " the old signature will be removed in v1.5" ) args.append(opt_idx) elif not self.trainer.has_arg( @@ -682,10 +682,10 @@ def _build_kwargs(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Optio has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") if has_opt_idx_in_train_step: if not lightning_module.automatic_optimization: - self.warning_cache.warn( + self.warning_cache.deprecation( "`training_step` hook signature has changed in v1.3." " `optimizer_idx` argument has been removed in case of manual optimization. 
Support for" - " the old signature will be removed in v1.5", DeprecationWarning + " the old signature will be removed in v1.5" ) step_kwargs['optimizer_idx'] = opt_idx elif not has_opt_idx_in_train_step and lightning_module.automatic_optimization: diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py index d029c525d71ac..67fc2e2a6f72c 100644 --- a/pytorch_lightning/loops/training_epoch_loop.py +++ b/pytorch_lightning/loops/training_epoch_loop.py @@ -231,10 +231,10 @@ def _on_train_epoch_end_hook(self, processed_epoch_output: List[List[STEP_OUTPUT if is_overridden(hook_name, model_ref): hook_fx = getattr(model_ref, hook_name) if is_param_in_hook_signature(hook_fx, "outputs"): - self.warning_cache.warn( + self.warning_cache.deprecation( "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." " `outputs` parameter has been deprecated." - " Support for the old signature will be removed in v1.5", DeprecationWarning + " Support for the old signature will be removed in v1.5", ) model_ref.on_train_epoch_end(processed_epoch_output) else: diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 8d2cc217835fb..c8ff9298a56bb 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -28,16 +28,15 @@ from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8 -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.cloud_io import load as pl_load -from pytorch_lightning.utilities.distributed import ( +from pytorch_lightning.utilities import ( + _TORCH_GREATER_EQUAL_1_7, + _TORCH_GREATER_EQUAL_1_8, rank_zero_deprecation, - rank_zero_only, rank_zero_warn, - ReduceOp, - sync_ddp_if_available, ) +from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available from pytorch_lightning.utilities.seed import reset_seed if _TORCH_GREATER_EQUAL_1_8: diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 8f613081cdfe2..c57e715eccf91 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -15,7 +15,6 @@ import json import logging import os -import warnings from collections import OrderedDict from pathlib import Path from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union @@ -33,6 +32,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE +from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning if _DEEPSPEED_AVAILABLE: import deepspeed @@ -260,10 +260,11 @@ def __init__( ) if cpu_offload or cpu_offload_params or cpu_offload_use_pin_memory: - warnings.warn( + _warn( "The usage of `cpu_offload`, `cpu_offload_params`, and `cpu_offload_use_pin_memory` " "is deprecated since v1.4 and will be 
removed in v1.5." - " From now on use `offload_optimizer`, `offload_parameters` and `pin_memory`.", DeprecationWarning + " From now on use `offload_optimizer`, `offload_parameters` and `pin_memory`.", + category=LightningDeprecationWarning ) offload_optimizer = cpu_offload offload_parameters = cpu_offload_params diff --git a/pytorch_lightning/profiler/base.py b/pytorch_lightning/profiler/base.py index 2a064085e8da7..d327d34e8d8c6 100644 --- a/pytorch_lightning/profiler/base.py +++ b/pytorch_lightning/profiler/base.py @@ -19,7 +19,7 @@ from pathlib import Path from typing import Any, Callable, Dict, Optional, TextIO, Union -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.cloud_io import get_filesystem log = logging.getLogger(__name__) @@ -63,10 +63,9 @@ def __init__( self.dirpath = dirpath self.filename = filename if output_filename is not None: - rank_zero_warn( + rank_zero_deprecation( "`Profiler` signature has changed in v1.3. The `output_filename` parameter has been removed in" " favor of `dirpath` and `filename`. Support for the old signature will be removed in v1.5", - DeprecationWarning ) filepath = Path(output_filename) self.dirpath = filepath.parent diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index b78922d7f4a47..533f592c06182 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -24,7 +24,7 @@ from torch.autograd.profiler import record_function from pytorch_lightning.profiler.base import BaseProfiler -from pytorch_lightning.utilities.distributed import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE @@ -349,9 +349,9 @@ def __deprecation_check( record_functions = set() if profiled_functions is not None: - rank_zero_warn( + rank_zero_deprecation( "`PyTorchProfiler.profiled_functions` has been renamed to" - " `record_functions` in v1.3 and will be removed in v1.5", DeprecationWarning + " `record_functions` in v1.3 and will be removed in v1.5" ) if not record_functions: record_functions |= set(profiled_functions) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 23df26b410a03..c4188c070e8b5 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -97,10 +97,10 @@ def on_train_epoch_end(self, outputs: EPOCH_OUTPUT): """ for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): - warning_cache.warn( + warning_cache.deprecation( "The signature of `Callback.on_train_epoch_end` has changed in v1.3." " `outputs` parameter has been removed." 
- " Support for the old signature will be removed in v1.5", DeprecationWarning + " Support for the old signature will be removed in v1.5", ) callback.on_train_epoch_end(self, self.lightning_module, outputs) else: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 8f5de9a6302aa..297e83330b2cb 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -67,8 +67,10 @@ device_parser, DeviceType, DistributedType, + rank_zero_deprecation, + rank_zero_info, + rank_zero_warn, ) -from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException if _HOROVOD_AVAILABLE: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 36e1535a0e683..fc8a04fd38520 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -283,10 +283,10 @@ def extra(self, extra: Mapping[str, Any]) -> None: def check_fn(v): if v.grad_fn is not None: # raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}') - warning_cache.warn( + warning_cache.deprecation( f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically" " but this behaviour will change in v1.6. Please detach it manually:" - " `return {'loss': ..., 'something': something.detach()}`", DeprecationWarning + " `return {'loss': ..., 'something': something.detach()}`", ) return v.detach() return v diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index f27288d2b13f4..4d93fa5977d13 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -14,8 +14,7 @@ from typing import Dict, List, Optional, Union from pytorch_lightning.callbacks import GradientAccumulationScheduler -from pytorch_lightning.utilities import GradClipAlgorithmType -from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 0a59b9d8d4c36..74603782f3293 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -14,7 +14,7 @@ from abc import ABC -from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.metrics import metrics_to_scalars as new_metrics_to_scalars diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 86cb1334a7067..cbf331913e597 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -16,7 +16,7 @@ from typing import Optional from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities import rank_zero_deprecation from 
pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f76568454b7ac..26b4fa0a1fbb3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -580,10 +580,10 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: if is_overridden(hook_name, model_ref): hook_fx = getattr(model_ref, hook_name) if is_param_in_hook_signature(hook_fx, "outputs"): - self.warning_cache.warn( + self.warning_cache.deprecation( "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." " `outputs` parameter has been deprecated." - " Support for the old signature will be removed in v1.5", DeprecationWarning + " Support for the old signature will be removed in v1.5", ) model_ref.on_train_epoch_end(processed_epoch_output) else: @@ -866,10 +866,10 @@ def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens): has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") if has_opt_idx_in_train_step: if not lightning_module.automatic_optimization: - self.warning_cache.warn( + self.warning_cache.deprecation( "`training_step` hook signature has changed in v1.3." " `optimizer_idx` argument has been removed in case of manual optimization. Support for" - " the old signature will be removed in v1.5", DeprecationWarning + " the old signature will be removed in v1.5", ) step_kwargs['optimizer_idx'] = opt_idx elif not has_opt_idx_in_train_step and self.trainer.lightning_module.automatic_optimization: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 613a5013d5198..c2e727d314396 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -16,13 +16,7 @@ import numpy from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 -from pytorch_lightning.utilities.distributed import ( # noqa: F401 - AllGatherGrad, - rank_zero_deprecation, - rank_zero_info, - rank_zero_only, - rank_zero_warn, -) +from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only # noqa: F401 from pytorch_lightning.utilities.enums import ( # noqa: F401 AMPType, DeviceType, @@ -63,6 +57,7 @@ _XLA_AVAILABLE, ) from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn # noqa: F401 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index 511a91326953d..ecb5d6ac00a03 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -16,7 +16,7 @@ import torch -from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_deprecation from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version @@ -121,12 +121,11 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in else: num_gpus = int(s.strip()) if _compare_version("pytorch_lightning", operator.lt, "1.5"): - rank_zero_warn( + rank_zero_deprecation( f"Parsing of the Trainer argument gpus='{s}' (string) will change in the 
future." " In the current version of Lightning, this will select" f" CUDA device with index {num_gpus}, but from v1.5 it will select gpus" f" {list(range(num_gpus))} (same as gpus={s} (int)).", - DeprecationWarning, ) return [num_gpus] return num_gpus diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index a507afa6bc895..d4511d9f0279d 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -14,13 +14,13 @@ import logging import os -import warnings -from functools import partial, wraps +from functools import wraps from typing import Any, Optional, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE if _TPU_AVAILABLE: @@ -65,22 +65,22 @@ def _get_rank() -> int: rank_zero_only.rank = getattr(rank_zero_only, 'rank', _get_rank()) -def _warn(*args, **kwargs): - warnings.warn(*args, **kwargs) +def _info(*args, stacklevel: int = 2, **kwargs): + log.info(*args, stacklevel=stacklevel, **kwargs) -def _info(*args, **kwargs): - log.info(*args, **kwargs) +def _debug(*args, stacklevel: int = 2, **kwargs): + log.debug(*args, stacklevel=stacklevel, **kwargs) -def _debug(*args, **kwargs): - log.debug(*args, **kwargs) +@rank_zero_only +def rank_zero_debug(*args, stacklevel: int = 4, **kwargs): + _debug(*args, stacklevel=stacklevel, **kwargs) -rank_zero_debug = rank_zero_only(_debug) -rank_zero_info = rank_zero_only(_info) -rank_zero_warn = rank_zero_only(_warn) -rank_zero_deprecation = partial(rank_zero_warn, category=DeprecationWarning) +@rank_zero_only +def rank_zero_info(*args, stacklevel: int = 4, **kwargs): + _info(*args, stacklevel=stacklevel, **kwargs) def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index a3dde95fa928f..865b8a8313e9b 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -11,18 +11,39 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.distributed import rank_zero_warn +import warnings +from functools import partial +from pytorch_lightning.utilities.distributed import rank_zero_only -class WarningCache: - def __init__(self): - self.warnings = set() +def _warn(*args, stacklevel: int = 2, **kwargs): + warnings.warn(*args, stacklevel=stacklevel, **kwargs) - def warn(self, m, *args, **kwargs): - if m not in self.warnings: - self.warnings.add(m) - rank_zero_warn(m, *args, **kwargs) - def clear(self): - self.warnings.clear() +@rank_zero_only +def rank_zero_warn(*args, stacklevel: int = 4, **kwargs): + _warn(*args, stacklevel=stacklevel, **kwargs) + + +class LightningDeprecationWarning(DeprecationWarning): + ... 
+ + +# enable our warnings +warnings.simplefilter('default', LightningDeprecationWarning) + +rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning) + + +class WarningCache(set): + + def warn(self, m, *args, stacklevel: int = 5, **kwargs): + if m not in self: + self.add(m) + rank_zero_warn(m, *args, stacklevel=stacklevel, **kwargs) + + def deprecation(self, m, *args, stacklevel: int = 5, **kwargs): + if m not in self: + self.add(m) + rank_zero_deprecation(m, *args, stacklevel=stacklevel, **kwargs) From f01b3b260e67f562a9253fde824c838d4469d984 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Jun 2021 14:58:58 +0200 Subject: [PATCH 06/17] Fix imports --- pytorch_lightning/utilities/distributed.py | 2 +- pytorch_lightning/utilities/parsing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index d4511d9f0279d..db3ccb18229b2 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -20,7 +20,6 @@ import torch from torch.nn.parallel.distributed import DistributedDataParallel -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE if _TPU_AVAILABLE: @@ -294,6 +293,7 @@ def register_ddp_comm_hook( ddp_comm_wrapper=default.fp16_compress_wrapper, ) """ + from pytorch_lightning.utilities import rank_zero_warn if not _TORCH_GREATER_EQUAL_1_8: rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.") return diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 6141a80b5f97c..a12af8994f6f6 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -18,7 +18,7 @@ from argparse import Namespace from typing import Any, Dict, Optional, Sequence, Tuple, Union -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.warnings import rank_zero_warn def str_to_bool_or_str(val: str) -> Union[str, bool]: @@ -97,7 +97,7 @@ def clean_namespace(hparams): del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)] for k in del_attrs: - rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled", UserWarning) + rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled") del hparams_dict[k] From 9c19b49fa18bc7897132ae359ae48ab24acb3e5c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 16 Jun 2021 15:36:38 +0200 Subject: [PATCH 07/17] Fix import --- pytorch_lightning/profiler/profilers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/profiler/profilers.py b/pytorch_lightning/profiler/profilers.py index c97dab0c8968b..3f534ce0bb425 100644 --- a/pytorch_lightning/profiler/profilers.py +++ b/pytorch_lightning/profiler/profilers.py @@ -1,7 +1,7 @@ -from pytorch_lightning.utilities.distributed import rank_zero_deprecation +from pytorch_lightning.utilities import rank_zero_deprecation rank_zero_deprecation( - "Using ``import pytorch_lightning.profiler.profilers`` is depreceated in v1.4, and will be removed in v1.6. " + "Using ``import pytorch_lightning.profiler.profilers`` is deprecated in v1.4, and will be removed in v1.6. " "HINT: Use ``import pytorch_lightning.profiler`` directly." 
) From 6db52c3b753d1b57808cbe28049e715f8dc36438 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Jun 2021 14:14:43 +0200 Subject: [PATCH 08/17] Bad merge --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index cd8bb31b319c2..c57e715eccf91 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -29,7 +29,7 @@ from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import _warn, rank_zero_info, rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.warnings import _warn, LightningDeprecationWarning From 86de6bbb23c728d79567166db1fc8ec9bcfbc48c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Jun 2021 14:22:22 +0200 Subject: [PATCH 09/17] Fix wrong fix --- pytorch_lightning/core/lightning.py | 2 +- tests/deprecated_api/test_remove_1-5.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index df9a3a0d362f1..6af8e8b308ea4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -171,7 +171,7 @@ def datamodule(self) -> Any: warning_cache.deprecation( "The `LightningModule.datamodule` property is deprecated in v1.3 and will be removed in v1.5." 
" Access the datamodule through using `self.trainer.datamodule` instead.", - stacklevel=5, + stacklevel=6, ) return self._datamodule diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index d4df7f2e65034..62020b62a4768 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -369,10 +369,8 @@ def test_v1_5_0_datamodule_setter(): datamodule = BoringDataModule() with no_deprecated_call(match="The `LightningModule.datamodule`"): model.datamodule = datamodule - from pytorch_lightning.core.lightning import warning_cache - warning_cache.clear() - _ = model.datamodule - assert any("The `LightningModule.datamodule`" in w for w in warning_cache) + with pytest.deprecated_call(match="The `LightningModule.datamodule`"): + _ = model.datamodule def test_v1_5_0_trainer_tbptt_steps(tmpdir): From 799758a91d9ac8e350048514fb7471479fe6a698 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Jun 2021 14:34:39 +0200 Subject: [PATCH 10/17] Avoid using deprecated model.datamodule --- .../trainer/connectors/data_connector.py | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 4ff7e5aa21a42..c21238f06fe8f 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -113,30 +113,28 @@ def attach_dataloaders( def attach_datamodule( self, model: 'pl.LightningModule', datamodule: Optional['pl.LightningDataModule'] = None ) -> None: - # We use datamodule if it's been provided, otherwise we check model for it - datamodule = datamodule or getattr(model, 'datamodule', None) - # If we have a datamodule, attach necessary hooks + dataloaders - if datamodule: - - # Override loader hooks - dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader') - for method in dl_methods: - if is_overridden(method, datamodule): - setattr(model, method, getattr(datamodule, method)) - - # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule - batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') - for hook in batch_transfer_hooks: - if is_overridden(hook, datamodule): - setattr(model, hook, getattr(datamodule, hook)) - - self.trainer.datamodule = datamodule - datamodule.trainer = self.trainer - - # experimental feature for Flash - if hasattr(datamodule, "data_pipeline"): - model.data_pipeline = datamodule.data_pipeline + if datamodule is None: + return + + # Override loader hooks + dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader') + for method in dl_methods: + if is_overridden(method, datamodule): + setattr(model, method, getattr(datamodule, method)) + + # Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule + batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer') + for hook in batch_transfer_hooks: + if is_overridden(hook, datamodule): + setattr(model, hook, getattr(datamodule, hook)) + + self.trainer.datamodule = datamodule + datamodule.trainer = self.trainer + + # experimental feature for Flash + if hasattr(datamodule, "data_pipeline"): + model.data_pipeline = datamodule.data_pipeline class _PatchDataLoader: From 59b6dee16498d94fbcec90f096d8c3b834bdaeaa Mon Sep 
17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Jun 2021 15:25:44 +0200 Subject: [PATCH 11/17] Only on >= 3.8 --- pytorch_lightning/utilities/distributed.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index db3ccb18229b2..eb35cdec6c13d 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -15,6 +15,7 @@ import logging import os from functools import wraps +from platform import python_version from typing import Any, Optional, Union import torch @@ -65,11 +66,15 @@ def _get_rank() -> int: def _info(*args, stacklevel: int = 2, **kwargs): - log.info(*args, stacklevel=stacklevel, **kwargs) + if python_version() >= "3.8.0": + kwargs['stacklevel'] = stacklevel + log.info(*args, **kwargs) def _debug(*args, stacklevel: int = 2, **kwargs): - log.debug(*args, stacklevel=stacklevel, **kwargs) + if python_version() >= "3.8.0": + kwargs['stacklevel'] = stacklevel + log.debug(*args, **kwargs) @rank_zero_only From 04aa536e9a19cec08394b301abd5c7f384e09769 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Jun 2021 15:46:38 +0200 Subject: [PATCH 12/17] Revert --- tests/deprecated_api/test_remove_1-5.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 62020b62a4768..d4df7f2e65034 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -369,8 +369,10 @@ def test_v1_5_0_datamodule_setter(): datamodule = BoringDataModule() with no_deprecated_call(match="The `LightningModule.datamodule`"): model.datamodule = datamodule - with pytest.deprecated_call(match="The `LightningModule.datamodule`"): - _ = model.datamodule + from pytorch_lightning.core.lightning import warning_cache + warning_cache.clear() + _ = model.datamodule + assert any("The `LightningModule.datamodule`" in w for w in warning_cache) def test_v1_5_0_trainer_tbptt_steps(tmpdir): From 0fc7187cc34d4b0cc9dace3d864a2a032fb7bb42 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Jun 2021 17:22:02 +0200 Subject: [PATCH 13/17] Add test --- tests/special_tests.sh | 6 ++++ tests/utilities/test_warnings.py | 53 ++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 tests/utilities/test_warnings.py diff --git a/tests/special_tests.sh b/tests/special_tests.sh index b6de1ca69ecef..9fca3b62bad40 100755 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -72,6 +72,12 @@ if nvcc --version; then nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} tests/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx fi +# needs to run outside of `pytest` +python tests/utilities/test_warnings.py +if [ $? -eq 0 ]; then + report+="Ran\ttests/utilities/test_warnings.py\n" +fi + # echo test report printf '=%.s' {1..80} printf "\n$report" diff --git a/tests/utilities/test_warnings.py b/tests/utilities/test_warnings.py new file mode 100644 index 0000000000000..a5cd2f3f49eeb --- /dev/null +++ b/tests/utilities/test_warnings.py @@ -0,0 +1,53 @@ +#!/bin/bash +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test that the warnings actually appear and they have the correct `stacklevel` + +Needs to be run outside of `pytest` as it captures all the warnings. +""" +import os +from contextlib import redirect_stderr +from io import StringIO + +from pytorch_lightning.utilities.warnings import _warn, rank_zero_deprecation, rank_zero_warn, WarningCache + +running_special = os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1" +if running_special: + + stderr = StringIO() + with redirect_stderr(stderr): + _warn("test1") + _warn("test2", DeprecationWarning) + + rank_zero_warn("test3") + rank_zero_warn("test4", DeprecationWarning) + + rank_zero_deprecation("test5") + + cache = WarningCache() + cache.warn("test6") + cache.deprecation("test7") + + output = stderr.getvalue() + assert "test_warnings.py:31: UserWarning: test1" in output + assert "test_warnings.py:32: DeprecationWarning: test2" in output + + assert "test_warnings.py:34: UserWarning: test3" in output + assert "test_warnings.py:35: DeprecationWarning: test4" in output + + assert "test_warnings.py:37: LightningDeprecationWarning: test5" in output + + assert "test_warnings.py:40: UserWarning: test6" in output + assert "test_warnings.py:41: LightningDeprecationWarning: test7" in output From 64dccc991df088e242952357fa210a02802efd10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 21 Jun 2021 12:44:26 +0200 Subject: [PATCH 14/17] Update tests/utilities/test_warnings.py --- tests/utilities/test_warnings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utilities/test_warnings.py b/tests/utilities/test_warnings.py index a5cd2f3f49eeb..e981f11fe2b5c 100644 --- a/tests/utilities/test_warnings.py +++ b/tests/utilities/test_warnings.py @@ -1,4 +1,3 @@ -#!/bin/bash # Copyright The PyTorch Lightning team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); From 015a7b886a3b91957a1298f30aad85ee9cf39046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 21 Jun 2021 13:15:11 +0200 Subject: [PATCH 15/17] Update tests/utilities/test_warnings.py --- tests/utilities/test_warnings.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/utilities/test_warnings.py b/tests/utilities/test_warnings.py index e981f11fe2b5c..2e0c372e5c39f 100644 --- a/tests/utilities/test_warnings.py +++ b/tests/utilities/test_warnings.py @@ -40,13 +40,13 @@ cache.deprecation("test7") output = stderr.getvalue() - assert "test_warnings.py:31: UserWarning: test1" in output - assert "test_warnings.py:32: DeprecationWarning: test2" in output + assert "test_warnings.py:30: UserWarning: test1" in output + assert "test_warnings.py:31: DeprecationWarning: test2" in output - assert "test_warnings.py:34: UserWarning: test3" in output - assert "test_warnings.py:35: DeprecationWarning: test4" in output + assert "test_warnings.py:33: UserWarning: test3" in output + assert "test_warnings.py:34: DeprecationWarning: test4" in output - assert "test_warnings.py:37: LightningDeprecationWarning: test5" in output + assert "test_warnings.py:36: LightningDeprecationWarning: test5" in output - assert "test_warnings.py:40: UserWarning: test6" in output - assert "test_warnings.py:41: LightningDeprecationWarning: test7" in output + assert "test_warnings.py:39: UserWarning: test6" in output + assert "test_warnings.py:40: LightningDeprecationWarning: test7" in output From 8815896a377438931ceb4a96b5c909abccb562fb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 21 Jun 2021 15:43:37 +0200 Subject: [PATCH 16/17] Install package for IPU testing --- .azure-pipelines/ipu-tests.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.azure-pipelines/ipu-tests.yml b/.azure-pipelines/ipu-tests.yml index c1474ee1c9187..065c6983a2abe 100644 --- a/.azure-pipelines/ipu-tests.yml +++ b/.azure-pipelines/ipu-tests.yml @@ -53,12 +53,9 @@ jobs: export GIT_TERMINAL_PROMPT=1 python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'fairscale' not in line] ; open(fname, 'w').writelines(lines)" python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)" - python ./requirements/adjust_versions.py requirements/extra.txt python ./requirements/adjust_versions.py requirements/examples.txt - - pip install --requirement ./requirements/devel.txt --upgrade-strategy only-if-needed - + pip install . --requirement requirements/devel.txt pip list displayName: 'Install dependencies' From fa7477919f0342770014acfb967ed98305df4235 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 21 Jun 2021 18:27:16 +0200 Subject: [PATCH 17/17] Add docstring --- pytorch_lightning/utilities/warnings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index 865b8a8313e9b..0595a41ea5aa0 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Warning-related utilities""" import warnings from functools import partial
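
Usage sketch (illustrative only, not one of the applied diffs): with this series, `training_step` extras that still carry a `grad_fn` emit a `LightningDeprecationWarning` and are detached automatically, and the warning asks users to detach them manually before the behaviour changes in v1.6. A minimal example of the user-side change, assuming a typical `LightningModule` whose `training_step` returns a dict (`compute_loss` is a hypothetical helper standing in for the real loss computation):

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):

        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical loss computation
            # Deprecated: 'foo' still has a grad_fn, so Lightning warns and detaches it for now
            # return {"loss": loss, "foo": loss}
            # Recommended, matching the warning message's suggestion:
            return {"loss": loss, "foo": loss.detach()}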