
Commit 7d32ceb

Merge branch 'master' into bugfix-#6947-Checkpoint-issue-when-using-Horovod-distributed-backend
2 parents f10b2a4 + 6c01608

18 files changed: +228 -43 lines changed

.github/workflows/ci_test-base.yml

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ jobs:
     - name: Setup macOS
       if: runner.os == 'macOS'
       run: |
+        brew update # todo: find a better fix (libomp error)
         brew install libomp # https://github.com/pytorch/pytorch/issues/20030

     # Note: This uses an internal pip API and may not always work

.github/workflows/ci_test-full.yml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ jobs:
     - name: Setup macOS
       if: runner.os == 'macOS'
      run: |
+        brew update # todo: find a better fix (libomp error)
         brew install libomp # https://github.com/pytorch/pytorch/issues/20030
         brew install openmpi libuv # Horovod on macOS requires OpenMPI, Gloo not currently supported

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -240,6 +240,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


+- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
+
+
+- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))
+
+
 ## [1.2.7] - 2021-04-06

 ### Fixed

dockers/base-cuda/Dockerfile

Lines changed: 4 additions & 0 deletions
@@ -113,6 +113,10 @@ RUN \
     pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex && \
     rm -rf apex

+RUN \
+    # install DeepSpeed
+    pip install deepspeed>=0.3.14
+
 RUN \
     # Show what we have
     pip --version && \

pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
         self.precision_plugin.pre_dispatch()

     def post_dispatch(self, trainer: 'pl.Trainer') -> None:
-        """Hook to do something before the training/evaluation/prediction starts."""
+        """Hook to do something after the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()
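
Note: the docstring fix only clarifies when `post_dispatch` runs. As a minimal, hypothetical sketch (not part of this commit) of how a custom accelerator could use the hook:

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator


class VerboseAccelerator(Accelerator):
    """Hypothetical subclass used only to illustrate the hook order."""

    def post_dispatch(self, trainer: 'pl.Trainer') -> None:
        # the plugins are torn down here, after the training/evaluation/prediction run has finished
        super().post_dispatch(trainer)
        print("run finished, plugins torn down")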

pytorch_lightning/core/step_result.py

Lines changed: 3 additions & 2 deletions
@@ -21,7 +21,7 @@
 from torch import Tensor
 from torchmetrics import Metric

-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


 class Result(Dict):
@@ -105,10 +105,11 @@ def log(

         # sync across workers when using distributed training
         sync_fn = sync_fn or sync_ddp_if_available
+
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if is_dist_initialized and isinstance(value, torch.Tensor):
+            if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
                 value = value.clone()
             else:
                 value = torch.tensor(value, device=device, dtype=torch.float)
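
For context, `sync_dist` is the flag users pass to `LightningModule.log`; with this change it also triggers the reduction path on TPU processes. A minimal hedged sketch of the user-facing call (module and metric names are illustrative, not from this commit):

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Illustrative fragment; optimizers and dataloaders omitted."""

    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # stand-in for a real loss computation
        # sync_dist=True reduces the value across processes before logging;
        # after this fix that also covers TPU runs
        self.log("val_loss", loss, sync_dist=True)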

pytorch_lightning/overrides/data_parallel.py

Lines changed: 34 additions & 0 deletions
@@ -22,6 +22,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.overrides.distributed import LightningDistributedModule
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection


@@ -71,6 +72,8 @@ def __init__(self, pl_module: LightningModule):
         super().__init__(pl_module)

     def forward(self, *inputs, **kwargs):
+        self.update_replica_device_attributes(inputs)
+        # forward call will redirect to training_step, validation_step, etc.
         output = super().forward(*inputs, **kwargs)

         def output_transform(data: Any):
@@ -85,6 +88,37 @@ def output_transform(data: Any):
         )
         return output

+    def update_replica_device_attributes(self, inputs: Any) -> None:
+        """
+        Updates the device information of LightningModule by reading the device from the inputs.
+        In :class:`~torch.nn.data_parallel.DataParallel` changes to the state during the `forward` pass
+        are lost when the replicas get discarded. The only way to know the current device is from the
+        inputs passed into the model.
+
+        Args:
+            inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors,
+                a warning is shown that accessing ``self.device`` will not return the correct device.
+        """
+        replica_device = None
+
+        def find_tensor_with_device(tensor: torch.Tensor) -> torch.Tensor:
+            nonlocal replica_device
+            if replica_device is None and tensor.device != torch.device("cpu"):
+                replica_device = tensor.device
+            return tensor
+
+        apply_to_collection(inputs, dtype=torch.Tensor, function=find_tensor_with_device)
+
+        if replica_device is not None:
+            # by calling .to() we force the update to the self.device property
+            self.module.to(device=replica_device)
+        else:
+            rank_zero_warn(
+                "Could not determine on which device the inputs are."
+                " When using DataParallel (accelerator='dp'), be aware that in case you are using self.device"
+                " in your code, it will reference only the root device."
+            )
+

 def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
     """ Converts a Python scalar number to a torch tensor and places it on the given device. """

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 28 additions & 5 deletions
@@ -102,6 +102,7 @@ def __init__(
         cpu_checkpointing: bool = False,
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
+        save_full_weights: bool = True,
     ) -> None:
         """
@@ -177,11 +178,16 @@ def __init__(
                 Not supported by all models

             synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
+
+            save_full_weights: Gathers weights across all processes before saving to disk
+                when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
+                rather than individual sharded weight files.
+                Disable to save sharded states individually. (Default: True)
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
                 "To use the DeepSpeed plugin, you must have DeepSpeed installed."
-                " pip install deepspeed mpi4py"
+                " pip install deepspeed"
             )
         super().__init__(
             parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
@@ -205,11 +211,13 @@ def __init__(
             allgather_partitions=allgather_partitions,
             reduce_scatter=reduce_scatter,
             allgather_bucket_size=allgather_bucket_size,
-            reduce_bucket_size=reduce_bucket_size
+            reduce_bucket_size=reduce_bucket_size,
         )
         self._config_initialized = False
         deepspeed.utils.logging.logger.setLevel(logging_level)

+        self.save_full_weights = save_full_weights
+
         # default FP16 parameters.
         self.loss_scale = loss_scale
         self.initial_scale_power = initial_scale_power
@@ -472,17 +480,27 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.

         Args:
+            checkpoint: The checkpoint state dictionary
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
         if self.world_size > 1 and self.zero_stage_3:
+            if self.save_full_weights:
+                # todo: expose this as general function in deepspeed
+                state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
+                if self.is_global_zero:
+                    # State dict keys will include reference to wrapper LightningDeepSpeedModule
+                    # Delete `module` prefix before saving.
+                    state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
+                    checkpoint['state_dict'] = state_dict
+                    return super().save_checkpoint(checkpoint, filepath)
+                return
+
             # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
             # dump states as a checkpoint dictionary object
             save_dir = self._filepath_to_dir(filepath)
             _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
             checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
             self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
-
         else:
             super().save_checkpoint(checkpoint, filepath)

@@ -491,7 +509,8 @@ def restore_model_state_from_ckpt_path(
         ckpt_path: str,
         map_location: Callable = lambda storage, loc: storage,
     ) -> Tuple[Dict, bool]:
-        if self.world_size > 1:
+        if not self.save_full_weights and self.world_size > 1:
+            # Rely on deepspeed to load the checkpoint and necessary information
             from pytorch_lightning.trainer.states import TrainerState
             stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
             save_dir = self._filepath_to_dir(ckpt_path)
@@ -511,6 +530,10 @@ def restore_model_state_from_ckpt_path(
             # hook: give user access to checkpoint if needed.
             self.lightning_module.on_load_checkpoint(client_state)
             return client_state, False
+
+        # Broadcast to ensure we load from the rank 0 checkpoint
+        # This doesn't have to be the case when using deepspeed sharded checkpointing
+        ckpt_path = self.broadcast(ckpt_path)
         return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)

     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
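
A hedged sketch of passing the new flag through the plugin; the surrounding Trainer arguments (gpus, precision, stage) are illustrative and not part of this diff:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# save_full_weights=True (the default) gathers ZeRO Stage 3 shards so the
# checkpoint contains the full model in a single file; set it to False to
# keep DeepSpeed's per-process sharded checkpoints instead
trainer = Trainer(
    gpus=4,
    precision=16,
    plugins=DeepSpeedPlugin(stage=3, save_full_weights=True),
)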

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 3 deletions
@@ -15,7 +15,7 @@
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

 import torch
 import torch.multiprocessing as mp
@@ -41,7 +41,6 @@
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf

-
 if TYPE_CHECKING:
     from torch.nn import Module
     from torch.utils.data import DataLoader
@@ -278,4 +277,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor.unsqueeze(0))
+        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        return xm.all_gather(tensor)
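
The `all_gather` change only adds a dimension to zero-dimensional tensors before handing them to `xm.all_gather`. A small torch-only sketch of that shape handling (no TPU needed; the helper name is made up for illustration):

import torch


def prepare_for_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    # mirrors the fixed logic: only scalar (0-dim) tensors get unsqueezed,
    # tensors that already have a leading dimension are left untouched
    if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    return tensor


assert prepare_for_all_gather(torch.tensor(3.14)).shape == (1,)
assert prepare_for_all_gather(torch.zeros(8, 16)).shape == (8, 16)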

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 3 deletions
@@ -53,12 +53,10 @@
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
+    _TPU_AVAILABLE,
     _XLA_AVAILABLE,
 )
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: F401
-
-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
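
`_TPU_AVAILABLE` is now imported alongside the other availability flags instead of being computed in `__init__`, so the public import path is unchanged. A minimal hedged sketch of guarding TPU-only code on it:

from pytorch_lightning.utilities import _TPU_AVAILABLE

if _TPU_AVAILABLE:
    # torch_xla is only importable in a TPU environment
    import torch_xla.core.xla_model as xm
    print("TPU device:", xm.xla_device())
else:
    print("No TPU available; falling back to CPU/GPU")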
