docs/source/advanced/mixed_precision.rst (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@ BFloat16 Mixed precision is similar to FP16 mixed precision, however we maintain
 Since BFloat16 is more stable than FP16 during training, we do not need to worry about any gradient scaling or nan gradient values that comes with using FP16 mixed precision.

 .. testcode::
-    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10
+    :skipif: not _TORCH_GREATER_EQUAL_DEV_1_10 or not torch.cuda.is_available()

     Trainer(gpus=1, precision="bf16")
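The tightened `:skipif:` means the doctest now also skips on CPU-only machines: `Trainer(gpus=1, precision="bf16")` needs both a bf16-capable torch build and a visible CUDA device. A minimal sketch of the same check outside Sphinx (the helper name is hypothetical, not part of the PR):

```python
# Hypothetical helper mirroring the doc's new skip condition; only
# torch.cuda.is_available() is a real API call here.
import torch

def can_run_bf16_gpu_snippet(torch_supports_bf16_amp: bool) -> bool:
    # Run the snippet only when torch is new enough for bf16 mixed precision
    # AND a CUDA device is actually present; either alone is not enough.
    return torch_supports_bf16_amp and torch.cuda.is_available()
```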
pytorch_lightning/core/lightning.py (3 changes: 2 additions & 1 deletion)
@@ -38,6 +38,7 @@
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.utilities import (
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_DEV_1_10,
     GradClipAlgorithmType,
     rank_zero_deprecation,
@@ -2041,7 +2042,7 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:

         These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
         """
-        if not _TORCH_GREATER_EQUAL_DEV_1_10:
+        if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
             return

         from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook
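The early return keeps `torch.distributed._sharded_tensor` from ever being imported where it is missing. A sketch of the pattern, assuming `_IS_WINDOWS` is derived from a standard platform check (the flag's exact definition is not shown in this diff):

```python
# Sketch of the guard pattern: compute the platform flag once, return early
# before the deferred import. Assumes _IS_WINDOWS is a plain platform check.
import platform

_IS_WINDOWS = platform.system() == "Windows"
_TORCH_GREATER_EQUAL_DEV_1_10 = True  # stand-in for the real version flag

def register_sharded_tensor_hooks() -> None:
    if not _TORCH_GREATER_EQUAL_DEV_1_10 or _IS_WINDOWS:
        return  # the module is absent here; importing it would raise
    # Deferred import: only reached on platforms where the module exists.
    from torch.distributed._sharded_tensor import pre_load_state_dict_hook, state_dict_hook  # noqa: F401
```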
pytorch_lightning/plugins/training_type/ddp.py (8 changes: 6 additions & 2 deletions)
@@ -42,6 +42,7 @@
 from pytorch_lightning.utilities import (
     _FAIRSCALE_AVAILABLE,
     _HYDRA_AVAILABLE,
+    _IS_WINDOWS,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
@@ -57,7 +58,9 @@
 from pytorch_lightning.utilities.types import STEP_OUTPUT

 if _TORCH_GREATER_EQUAL_1_10:
-    from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+    if not _IS_WINDOWS:
+        from torch.distributed.optim import DistributedOptimizer
+    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer

 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
@@ -333,8 +336,9 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         if isinstance(optimizer, LightningOptimizer):
             optimizer = optimizer._optimizer

+        is_distributed_optimizer = isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
         if (
-            isinstance(optimizer, DistributedOptimizer)
+            is_distributed_optimizer
             or isinstance(optimizer, ZeroRedundancyOptimizer)
             or (_FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS))
         ):
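Because `DistributedOptimizer` is only imported off Windows, naming it directly inside the `isinstance` chain would raise `NameError` on Windows; hoisting the check behind the same flag keeps the condition safe. A self-contained sketch of the pattern (the function name is illustrative, and it assumes a torch build where `torch.distributed` is available):

```python
# Guarded isinstance(): the class name is only evaluated on platforms where
# the import below actually ran, so Windows short-circuits to False.
import platform

_IS_WINDOWS = platform.system() == "Windows"

if not _IS_WINDOWS:
    from torch.distributed.optim import DistributedOptimizer

def uses_distributed_optimizer(optimizer) -> bool:
    # On Windows the class was never imported; referencing it would raise.
    return isinstance(optimizer, DistributedOptimizer) if not _IS_WINDOWS else False
```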
pytorch_lightning/utilities/__init__.py (1 change: 1 addition & 0 deletions)
@@ -38,6 +38,7 @@
     _HYDRA_EXPERIMENTAL_AVAILABLE,
     _IPU_AVAILABLE,
     _IS_INTERACTIVE,
+    _IS_WINDOWS,
     _JSONARGPARSE_AVAILABLE,
     _module_available,
     _OMEGACONF_AVAILABLE,
tests/core/test_lightning_module.py (3 changes: 2 additions & 1 deletion)
@@ -21,7 +21,7 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_DEV_1_10
+from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_DEV_1_10
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -315,6 +315,7 @@ def __init__(self, spec):
 @pytest.mark.skipif(
     not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Test requires the torch version to support `ShardedTensor`"
 )
+@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
 def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
     spec = dist._sharding_spec.ChunkShardingSpec(
         dim=0,
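Stacked `skipif` marks compose with OR semantics: pytest skips the test if any mark's condition is true, so the new platform guard sits alongside the existing version guard without modifying it. A small sketch (the version flag is a stand-in, not the library's real constant):

```python
# Each skipif mark is evaluated independently; the test runs only when every
# condition is False.
import platform

import pytest

_IS_WINDOWS = platform.system() == "Windows"
_TORCH_GREATER_EQUAL_DEV_1_10 = True  # stand-in for the real version flag

@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="needs ShardedTensor support")
@pytest.mark.skipif(_IS_WINDOWS, reason="Not supported on Windows")
def test_skipped_on_windows():
    assert True
```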