
Commit dbba754

Replace _TORCH_GREATER_EQUAL_DEV_1_10 with _TORCH_GREATER_EQUAL_1_10
1 parent c33df26 commit dbba754

File tree

10 files changed (+18 -26 lines)


docs/source/advanced/mixed_precision.rst

Lines changed: 2 additions & 2 deletions
@@ -50,14 +50,14 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.

 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
+    :skipif: not _TORCH_GREATER_EQUAL_1_10 or not torch.cuda.is_available()

     Trainer(gpus=1, precision="bf16")

 It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN under the hood.

 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_1_10

     Trainer(precision="bf16")
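As a usage aside (not part of the diff), here is a hedged sketch of how a training script might pick between the two documented configurations at runtime; it assumes a PyTorch 1.10+ install, mirroring the `:skipif:` guard above.

    import torch
    from pytorch_lightning import Trainer

    # Illustrative only: choose the GPU or CPU bf16 configuration shown in the docs.
    # Assumes torch >= 1.10 is installed.
    if torch.cuda.is_available():
        trainer = Trainer(gpus=1, precision="bf16")
    else:
        trainer = Trainer(precision="bf16")  # CPU bf16 path, backed by MKLDNN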

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -373,7 +373,7 @@ def package_list_from_file(file):
     _XLA_AVAILABLE,
     _TPU_AVAILABLE,
     _TORCHVISION_AVAILABLE,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
+    _TORCH_GREATER_EQUAL_1_10,
     _module_available,
 )
 _JSONARGPARSE_AVAILABLE = _module_available("jsonargparse")

pytorch_lightning/callbacks/quantization.py

Lines changed: 3 additions & 3 deletions
@@ -33,10 +33,10 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if _TORCH_GREATER_EQUAL_DEV_1_10:
+if _TORCH_GREATER_EQUAL_1_10:
     from torch.ao.quantization.qconfig import QConfig
 else:
     from torch.quantization import QConfig
@@ -245,7 +245,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
                 # version=None corresponds to using FakeQuantize rather than
                 # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
                 # details in https://github.com/pytorch/pytorch/issues/64564
-                extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_DEV_1_10 else {}
+                extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
                 pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)

         elif isinstance(self._qconfig, QConfig):
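For context on the `extra_kwargs` gate above, here is a minimal sketch of the underlying torch call; the "fbgemm" backend string and the inline version check are illustrative choices, not something this hunk specifies.

    import torch

    # Sketch: on torch >= 1.10, get_default_qat_qconfig accepts a `version` argument.
    # Passing version=None keeps the pre-1.10 FakeQuantize observers instead of the
    # newer FusedMovingAvgObsFakeQuantize (see pytorch/pytorch#64564).
    torch_major_minor = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
    extra_kwargs = dict(version=None) if torch_major_minor >= (1, 10) else {}
    qconfig = torch.quantization.get_default_qat_qconfig("fbgemm", **extra_kwargs)
    print(qconfig)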

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
@@ -39,7 +39,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
     _IS_WINDOWS,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
+    _TORCH_GREATER_EQUAL_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
     rank_zero_warn,
@@ -2043,7 +2043,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:

         These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
+        if not _TORCH_GREATER_EQUAL_1_10 or _IS_WINDOWS:
             return

         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
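The hunk only shows the guard and the import. Purely as a hedged sketch of what such hook wiring can look like on a plain nn.Module (the registration calls here are an assumption, not shown in the diff; it needs torch >= 1.10 with distributed support, on a non-Windows platform):

    from torch import nn
    from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook

    module = nn.Linear(4, 4)
    # Assumed wiring: a post-hook so ShardedTensors end up in state_dict(), and a
    # pre-hook so they are rebuilt when load_state_dict() runs.
    module._register_state_dict_hook(state_dict_hook)
    module._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)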

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 4 additions & 4 deletions
@@ -21,10 +21,10 @@

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10, AMPType
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if _TORCH_GREATER_EQUAL_DEV_1_10:
+if _TORCH_GREATER_EQUAL_1_10:
     from torch import autocast
 else:
     from torch.cuda.amp import autocast
@@ -47,7 +47,7 @@ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> No

     def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
         if precision == "bf16":
-            if not _TORCH_GREATER_EQUAL_DEV_1_10:
+            if not _TORCH_GREATER_EQUAL_1_10:
                 raise MisconfigurationException(
                     "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
                 )
@@ -97,7 +97,7 @@ def optimizer_step(
         self.scaler.update()

     def autocast_context_manager(self) -> autocast:
-        if _TORCH_GREATER_EQUAL_DEV_1_10:
+        if _TORCH_GREATER_EQUAL_1_10:
             return autocast("cpu" if self.use_cpu else "cuda", dtype=self._dtype)
         return autocast()
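Behind this gate is the API difference between the device-agnostic `torch.autocast` added in PyTorch 1.10 (which takes a device type and an explicit dtype, and therefore supports CPU bf16) and the older CUDA-only `torch.cuda.amp.autocast`. A small usage sketch, assuming torch >= 1.10:

    import torch

    model = torch.nn.Linear(8, 2)
    batch = torch.randn(4, 8)

    # Device-agnostic autocast (torch >= 1.10): CPU + bfloat16 is what
    # Trainer(precision="bf16") on CPU relies on via this plugin.
    with torch.autocast("cpu", dtype=torch.bfloat16):
        out = model(batch)

    print(out.dtype)  # expected torch.bfloat16, since linear is on the CPU autocast list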

pytorch_lightning/utilities/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
     _TORCH_GREATER_EQUAL_1_10,
-    _TORCH_GREATER_EQUAL_DEV_1_10,
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,

pytorch_lightning/utilities/imports.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
 _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
 _TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0")
-_TORCH_GREATER_EQUAL_DEV_1_10 = _compare_version("torch", operator.ge, "1.10.0", use_base_version=True)
+# _TORCH_GREATER_EQUAL_DEV_1_11 = _compare_version("torch", operator.ge, "1.11.0", use_base_version=True)

 _APEX_AVAILABLE = _module_available("apex.amp")
 _DEEPSPEED_AVAILABLE = _module_available("deepspeed")
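The difference between the removed `_DEV_` flag and the plain flag is the `use_base_version` comparison: with it, pre-releases such as 1.10.0.dev20210921 count as 1.10.0; without it, they compare as older. A minimal sketch of that semantics using `packaging` (not the actual `_compare_version` implementation):

    from packaging.version import Version

    installed = Version("1.10.0.dev20210921")

    # Strict comparison: a dev pre-release sorts *before* the final 1.10.0 release.
    print(installed >= Version("1.10.0"))                        # False

    # Base-version comparison (what use_base_version=True enabled): the pre-release
    # is truncated to 1.10.0, so nightly builds already satisfy the check.
    print(Version(installed.base_version) >= Version("1.10.0"))  # True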

tests/core/test_lightning_module.py

Lines changed: 1 addition & 5 deletions
@@ -21,7 +21,6 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -312,10 +311,7 @@ def __init__(self, spec):
         self.sharded_tensor.local_shards()[0].tensor.fill_(0)


-@pytest.mark.skipif(
-    not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
-)
-@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
+@RunIf(min_torch="1.10", skip_windows=True)
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,
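The `RunIf` marker used here lives in tests/helpers/runif.py and collapses both skips into one decorator. Purely as an illustration of the pattern (this is not that file's implementation), a similar helper can be assembled from `pytest.mark.skipif` and the `_compare_version` utility shown earlier:

    import operator
    import sys

    import pytest
    from pytorch_lightning.utilities.imports import _compare_version

    # Illustrative only: bundle several skip conditions into a single skipif marker,
    # roughly the way the RunIf helper used in these tests does.
    def run_if(min_torch: str = None, skip_windows: bool = False):
        conditions, reasons = [], []
        if min_torch is not None:
            conditions.append(not _compare_version("torch", operator.ge, min_torch))
            reasons.append(f"torch>={min_torch}")
        if skip_windows:
            conditions.append(sys.platform == "win32")
            reasons.append("does not run on Windows")
        return pytest.mark.skipif(any(conditions), reason=", ".join(reasons) or "conditions met")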

tests/models/test_amp.py

Lines changed: 2 additions & 4 deletions
@@ -22,7 +22,6 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_DEV_1_10
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf

@@ -68,7 +67,7 @@ def _assert_autocast_enabled(self):
         assert torch.is_autocast_enabled()


-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
+@RunIf(min_torch="1.10")
 @pytest.mark.parametrize(
     "strategy",
     [
@@ -95,8 +94,7 @@ def test_amp_cpus(tmpdir, strategy, precision, num_processes):
     assert trainer.state.finished, f"Training failed with {trainer.state}"


-@RunIf(min_gpus=2)
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
+@RunIf(min_gpus=2, min_torch="1.10")
 @pytest.mark.parametrize("strategy", [None, "dp", "ddp_spawn"])
 @pytest.mark.parametrize("precision", [16, "bf16"])
 @pytest.mark.parametrize("gpus", [1, 2])

tests/plugins/test_amp_plugins.py

Lines changed: 2 additions & 3 deletions
@@ -21,7 +21,6 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -178,7 +177,7 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
     trainer.fit(model)


-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Torch CPU AMP is not available.")
+@RunIf(min_torch="1.10")
 def test_cpu_amp_precision_context_manager(tmpdir):
     """Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
     plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
@@ -197,7 +196,7 @@ def test_precision_selection_raises(monkeypatch):

     import pytorch_lightning.plugins.precision.native_amp as amp

-    monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_DEV_1_10", False)
+    monkeypatch.setattr(amp, "_TORCH_GREATER_EQUAL_1_10", False)
     with pytest.warns(
         UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
     ), pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
