
Commit d30e456

Revert "Fix: skip importing DistributedOptimizer for Windows (#10071)"

This reverts commit c3614f1.

1 parent 3505c7c · commit d30e456

File tree: 5 files changed, +5 −12 lines

docs/source/advanced/mixed_precision.rst

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
 
 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
+    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
 
     Trainer(gpus=1, precision="bf16")
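For reference, the testcode gated by this hunk is the documented BFloat16 entry point. A minimal sketch of the same usage, assuming a CUDA machine and a torch build that satisfies _TORCH_GREATER_EQUAL_DEV_1_10 (the model is a placeholder):

from pytorch_lightning import Trainer

# BFloat16 mixed precision on one GPU, as in the doc snippet above.
# Per the surrounding docs, bf16 needs no gradient scaling, unlike FP16.
trainer = Trainer(gpus=1, precision="bf16")
# trainer.fit(model)  # model: any LightningModule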
pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 2 deletions

@@ -38,7 +38,6 @@
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
-    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_DEV_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
@@ -2042,7 +2041,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
 
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
+        if not _TORCH_GREATER_EQUAL_DEV_1_10:
             return
 
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
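After the revert, the guard in _register_sharded_tensor_state_dict_hooks_if_available depends only on the torch version. A sketch of the resulting shape, reconstructed from the hunk above with the hook-registration body elided:

def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
    # Version gate only; the reverted variant also returned early on Windows.
    if not _TORCH_GREATER_EQUAL_DEV_1_10:
        return

    from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
    # ... register state_dict_hook and pre_load_state_dict_hook on this module ...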

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 2 additions & 6 deletions

@@ -42,7 +42,6 @@
 from pytorch_lightning.utilities import (
     _FAIRSCALE_AVAILABLE,
     _HYDRA_AVAILABLE,
-    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
@@ -58,9 +57,7 @@
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
-    if not _IS_WINDOWS:
-        from torch.distributed.optim import DistributedOptimizer
-    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+    from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -336,9 +333,8 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer
 
-        is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
         if (
-            is_distributed_optimizer
+            isinstance(optimizer, DistributedOptimizer)
             or isinstance(optimizer, ZeroRedundancyOptimizer)
             or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
         ):
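The code removed here followed the usual pattern for platform-gated imports: import where possible, and make call sites tolerate the missing name. A generic sketch of that pattern as reverted, with _IS_WINDOWS recreated locally for illustration (the Lightning constant lives in pytorch_lightning.utilities):

import platform

_IS_WINDOWS = platform.system() == "Windows"

if not _IS_WINDOWS:
    # The reverted fix (#10071) skipped this import on Windows, where
    # DistributedOptimizer could not be imported at the time.
    from torch.distributed.optim import DistributedOptimizer

def is_distributed(optimizer) -> bool:
    # Only consult DistributedOptimizer where it was actually imported,
    # mirroring the removed is_distributed_optimizer expression above.
    return isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False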

pytorch_lightning/utilities/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -38,7 +38,6 @@
     _HYDRA_EXPERIMENTAL_AVAILABLE,
     _IPU_AVAILABLE,
     _IS_INTERACTIVE,
-    _IS_WINDOWS,
     _JSONARGPARSE_AVAILABLE,
     _module_available,
     _OMEGACONF_AVAILABLE,

tests/core/test_lightning_module.py

Lines changed: 1 addition & 2 deletions

@@ -21,7 +21,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -315,7 +315,6 @@ def __init__(self, spec):
 @pytest.mark.skipif(
     not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
 )
-@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,
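The marker removed in the second hunk is the standard pytest idiom for platform skips. A self-contained generic sketch of that idiom (test name and body are placeholders):

import platform

import pytest

@pytest.mark.skipif(platform.system() == "Windows", reason="Not supported on Windows")
def test_unix_only_behavior():
    # pytest reports this test as skipped on Windows and runs it elsewhere.
    assert True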
