From dbba7543c59498de512a0a7161c2f20cec39bcaa Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 26 Oct 2021 17:55:22 +0200
Subject: [PATCH] Replace `_TORCH_GREATER_EQUAL_DEV_1_10` with `_TORCH_GREATER_EQUAL_1_10`

---
 docs/source/advanced/mixed_precision.rst           | 4 ++--
 docs/source/conf.py                                | 2 +-
 pytorch_lightning/callbacks/quantization.py        | 6 +++---
 pytorch_lightning/core/lightning.py                | 4 ++--
 pytorch_lightning/plugins/precision/native_amp.py  | 8 ++++----
 pytorch_lightning/utilities/__init__.py            | 1 -
 pytorch_lightning/utilities/imports.py             | 2 +-
 tests/core/test_lightning_module.py                | 6 +-----
 tests/models/test_amp.py                           | 6 ++----
 tests/plugins/test_amp_plugins.py                  | 5 ++---
 10 files changed, 18 insertions(+), 26 deletions(-)

diff --git a/docs/source/advanced/mixed_precision.rst b/docs/source/advanced/mixed_precision.rst
index 1c98f663ed5f3..9889c05db243d 100644
--- a/docs/source/advanced/mixed_precision.rst
+++ b/docs/source/advanced/mixed_precision.rst
@@ -50,14 +50,14 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
 
 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
+    :skipif: not _TORCH_GREATER_EQUAL_1_10 or not torch.cuda.is_available()
 
     Trainer(gpus=1, precision="bf16")
 
 It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.
 
 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_1_10
 
     Trainer(precision="bf16")

diff --git a/docs/source/conf.py b/docs/source/conf.py
index f5f9605263217..16b2ed7509ee3 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -373,7 +373,7 @@ def package_list_from_file(file):
     _XLA_AVAILABLE,
     _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
+    _TORCH_GREATER_EQUAL_1_10,
     _module_available,
 )
 _JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")

diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py
index bf0088575e8b4..ca82a574f71d1 100644
--- a/pytorch_lightning/callbacks/quantization.py
+++ b/pytorch_lightning/callbacks/quantization.py
@@ -33,10 +33,10 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TORCH_GREATER_EQUAL_DEV_1_10:
+if _TORCH_GREATER_EQUAL_1_10:
     from torch.ao.quantization.qconfig import QConfig
 else:
     from torch.quantization import QConfig
@@ -245,7 +245,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                 # version=None corresponds to using FakeQuantize rather than
                 # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
                 # details in https://github.com/pytorch/pytorch/issues/64564
-                extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_DEV_1_10 else {}
+                extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
                 pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
 
         elif isinstance(self._qconfig, QConfig):

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 7a58f91adda7d..cfac84be1367b 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -39,7 +39,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
     _IS_WINDOWS,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
+    _TORCH_GREATER_EQUAL_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
     rank_zero_warn,
@@ -2043,7 +2043,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
 
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
+        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
             return
 
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 3fc903cbb3fce..487d80005c222 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -21,10 +21,10 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10, AMPType
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TORCH_GREATER_EQUAL_DEV_1_10:
+if _TORCH_GREATER_EQUAL_1_10:
     from torch import autocast
 else:
     from torch.cuda.amp import autocast
@@ -47,7 +47,7 @@ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> No
 
     def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
         if precision == "bf16":
-            if not _TORCH_GREATER_EQUAL_DEV_1_10:
+            if not _TORCH_GREATER_EQUAL_1_10:
                 raise MisconfigurationException(
                     "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
                 )
@@ -97,7 +97,7 @@ def optimizer_step(
         self.scaler.update()
 
     def autocast_context_manager(self) -> autocast:
-        if _TORCH_GREATER_EQUAL_DEV_1_10:
+        if _TORCH_GREATER_EQUAL_1_10:
             return autocast("cpu" if self.use_cpu else "cuda", dtype=self._dtype)
         return autocast()
 
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index bc19aa1366a55..158d7356c91ce 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -48,7 +48,6 @@
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
     _TORCH_GREATER_EQUAL_1_10,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,

diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index c7ad70895672a..811e81a370601 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -75,7 +75,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
 _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
 _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
-_TORCH_GREATER_EQUAL_DEV_1_10 = _compare_version("torch", operator.ge, "1.10.0", use_base_version=True)
+# _TORCH_GREATER_EQUAL_DEV_1_11 = _compare_version("torch", operator.ge, "1.11.0", use_base_version=True)
 
 _APEX_AVAILABLE = _module_available("apex.amp")
 _DEEPSPEED_AVAILABLE = _module_available("deepspeed")

diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index d661228ee09d8..ff8ffa3c50acd 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -21,7 +21,6 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -312,10 +311,7 @@ def __init__(self, spec):
         self.sharded_tensor.local_shards()[0].tensor.fill_(0)
 
 
-@pytest.mark.skipif(
-    not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
-)
-@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
+@RunIf(min_torch="1.10", skip_windows=True)
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,

diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 716c0f17f203d..86863238da057 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -22,7 +22,6 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_DEV_1_10
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -68,7 +67,7 @@ def _assert_autocast_enabled(self):
         assert torch.is_autocast_enabled()
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
+@RunIf(min_torch="1.10")
 @pytest.mark.parametrize(
     "strategy",
     [
@@ -95,8 +94,7 @@ def test_amp_cpus(tmpdir, strategy, precision, num_processes):
     assert trainer.state.finished, f"Training failed with {trainer.state}"
 
 
-@RunIf(min_gpus=2)
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
+@RunIf(min_gpus=2, min_torch="1.10")
 @pytest.mark.parametrize("strategy", [None, "dp", "ddp_spawn"])
 @pytest.mark.parametrize("precision", [16, "bf16"])
 @pytest.mark.parametrize("gpus", [1, 2])

diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index ed8c653b3a78f..227d898a7da40 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -21,7 +21,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -178,7 +177,7 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
     trainer.fit(model)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Torch CPU AMP is not available.")
+@RunIf(min_torch="1.10")
 def test_cpu_amp_precision_context_manager(tmpdir):
     """Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
     plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
@@ -197,7 +196,7 @@ def test_precision_selection_raises(monkeypatch):
 
     import pytorch_lightning.plugins.precision.native_amp as amp
 
-    monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_DEV_1_10", False)
+    monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_1_10", False)
     with pytest.warns(
         UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
     ), pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
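
Note for reviewers (not part of the patch): the flags touched here come from the `_compare_version` helper in `pytorch_lightning/utilities/imports.py`. The sketch below is a minimal, standalone illustration of that idea, not Lightning's actual implementation; it assumes only the `packaging` library and shows why `use_base_version=True` mattered while 1.10 existed only as dev/nightly builds, and why the plain `_TORCH_GREATER_EQUAL_1_10` comparison is enough once 1.10.0 is released.

    # Illustrative sketch only -- not Lightning's `_compare_version` implementation.
    import operator
    from importlib import import_module

    from packaging.version import Version


    def compare_version(package: str, op, version: str, use_base_version: bool = False) -> bool:
        """Return True if `package` is importable and its version satisfies `op(installed, version)`."""
        try:
            pkg = import_module(package)
        except ImportError:
            return False
        installed = Version(pkg.__version__)
        if use_base_version:
            # Strip dev/pre-release suffixes, e.g. "1.10.0.dev20210922" -> "1.10.0",
            # so nightly builds of an upcoming release also satisfy the comparison.
            installed = Version(installed.base_version)
        return op(installed, Version(version))


    # With torch 1.10.0 released, the plain comparison suffices, which is why the
    # `_DEV` variant of the flag can be retired in favour of `_TORCH_GREATER_EQUAL_1_10`.
    TORCH_GREATER_EQUAL_1_10 = compare_version("torch", operator.ge, "1.10.0")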