diff --git a/docs/source/advanced/mixed_precision.rst b/docs/source/advanced/mixed_precision.rst
index 09547052403de..1c98f663ed5f3 100644
--- a/docs/source/advanced/mixed_precision.rst
+++ b/docs/source/advanced/mixed_precision.rst
@@ -50,7 +50,7 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.
 
 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()
 
     Trainer(gpus=1, precision="bf16")
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 6f03fd93694d2..4546cc0c80af9 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -38,6 +38,7 @@
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_DEV_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
@@ -2041,7 +2042,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10:
+        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
             return
 
         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 64fc1a5a97277..4499c1d7dfc41 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -42,6 +42,7 @@
 from pytorch_lightning.utilities import (
     _FAIRSCALE_AVAILABLE,
     _HYDRA_AVAILABLE,
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
@@ -57,7 +58,9 @@
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
-    from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+    if not _IS_WINDOWS:
+        from torch.distributed.optim import DistributedOptimizer
+    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -333,8 +336,9 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
             if isinstance(optimizer, LightningOptimizer):
                 optimizer = optimizer._optimizer
 
+            is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
             if (
-                isinstance(optimizer, DistributedOptimizer)
+                is_distributed_optimizer
                 or isinstance(optimizer, ZeroRedundancyOptimizer)
                 or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
             ):
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 94d9159c9fef6..bc19aa1366a55 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -38,6 +38,7 @@
     _HYDRA_EXPERIMENTAL_AVAILABLE,
     _IPU_AVAILABLE,
     _IS_INTERACTIVE,
+    _IS_WINDOWS,
     _JSONARGPARSE_AVAILABLE,
     _module_available,
     _OMEGACONF_AVAILABLE,
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 18260339d82e3..135e437c4bc54 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -21,7 +21,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -315,6 +315,7 @@ def __init__(self, spec):
 @pytest.mark.skipif(
     not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
 )
+@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,
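For readers unfamiliar with the flag, `_IS_WINDOWS` is a module-level platform check re-exported from `pytorch_lightning.utilities`; `torch.distributed.optim.DistributedOptimizer` builds on the RPC framework, which is not supported on Windows, hence the guarded import and guarded `isinstance` check above. The snippet below is a minimal, illustrative sketch of that guard pattern, assuming a definition along the lines of `platform.system() == "Windows"` (the actual definition lives in `pytorch_lightning/utilities/imports.py`); the `_is_distributed_optimizer` helper is hypothetical and only mirrors the inline expression used in the diff.

import platform

# Assumed definition for illustration; mirrors the platform check exported by
# pytorch_lightning.utilities.
_IS_WINDOWS = platform.system() == "Windows"

if not _IS_WINDOWS:
    # DistributedOptimizer relies on torch.distributed.rpc, which is unavailable
    # on Windows, so the import only happens on other platforms.
    from torch.distributed.optim import DistributedOptimizer

def _is_distributed_optimizer(optimizer) -> bool:
    # Hypothetical helper showing the same fallback the diff uses inline:
    # on Windows the class was never imported, so the check is simply False.
    # (The condition is evaluated first, so the isinstance branch is skipped
    # on Windows and no NameError is raised.)
    return isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False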