
Commit 2fc0b67

Merge branch 'master' into fix/tpu_sync_dist
2 parents 548cad4 + 80c5293 commit 2fc0b67

File tree

11 files changed: +169 additions, -18 deletions


.github/workflows/ci_test-base.yml

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ jobs:
     - name: Setup macOS
       if: runner.os == 'macOS'
       run: |
+        brew update  # todo: find a better fix (libomp error)
         brew install libomp  # https://github.com/pytorch/pytorch/issues/20030
 
   # Note: This uses an internal pip API and may not always work

.github/workflows/ci_test-full.yml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ jobs:
     - name: Setup macOS
       if: runner.os == 'macOS'
      run: |
+        brew update  # todo: find a better fix (libomp error)
         brew install libomp  # https://github.com/pytorch/pytorch/issues/20030
         brew install openmpi libuv  # Horovod on macOS requires OpenMPI, Gloo not currently supported

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -243,6 +243,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
 
 
+- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed

dockers/base-cuda/Dockerfile

Lines changed: 4 additions & 0 deletions
@@ -113,6 +113,10 @@ RUN \
     pip install --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex && \
     rm -rf apex
 
+RUN \
+    # install DeepSpeed
+    pip install deepspeed>=0.3.14
+
 RUN \
     # Show what we have
     pip --version && \

pytorch_lightning/overrides/data_parallel.py

Lines changed: 34 additions & 0 deletions
@@ -22,6 +22,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.overrides.distributed import LightningDistributedModule
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 
@@ -71,6 +72,8 @@ def __init__(self, pl_module: LightningModule):
         super().__init__(pl_module)
 
     def forward(self, *inputs, **kwargs):
+        self.update_replica_device_attributes(inputs)
+        # forward call will redirect to training_step, validation_step, etc.
         output = super().forward(*inputs, **kwargs)
 
         def output_transform(data: Any):
@@ -85,6 +88,37 @@ def output_transform(data: Any):
         )
         return output
 
+    def update_replica_device_attributes(self, inputs: Any) -> None:
+        """
+        Updates the device information of LightningModule by reading the device from the inputs.
+        In :class:`~torch.nn.data_parallel.DataParallel` changes to the state during the `forward` pass
+        are lost when the replicas get discarded. The only way to know the current device is from the
+        inputs passed into the model.
+
+        Args:
+            inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors,
+                a warning is shown that accessing ``self.device`` will not return the correct device.
+        """
+        replica_device = None
+
+        def find_tensor_with_device(tensor: torch.Tensor) -> torch.Tensor:
+            nonlocal replica_device
+            if replica_device is None and tensor.device != torch.device("cpu"):
+                replica_device = tensor.device
+            return tensor
+
+        apply_to_collection(inputs, dtype=torch.Tensor, function=find_tensor_with_device)
+
+        if replica_device is not None:
+            # by calling .to() we force the update to the self.device property
+            self.module.to(device=replica_device)
+        else:
+            rank_zero_warn(
+                "Could not determine on which device the inputs are."
+                " When using DataParallel (accelerator='dp'), be aware that in case you are using self.device"
+                " in your code, it will reference only the root device."
+            )
+
 
 def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
     """ Converts a Python scalar number to a torch tensor and places it on the given device. """

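For context on the change above: with the new `update_replica_device_attributes` hook, `self.device` inside `training_step` follows the replica's GPU instead of always reporting the root device when training with DataParallel. A minimal user-side sketch of the behaviour (the model, data, and Trainer flags below are illustrative assumptions, not part of this commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class DeviceEchoModel(LightningModule):
    # hypothetical example model, not part of this commit
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        # before this fix, self.device always pointed at the root GPU (cuda:0);
        # with it, self.device matches the device of the scattered inputs in each replica
        assert self.device == x.device
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    # accelerator='dp' wraps the module in LightningParallelModule under the hood
    trainer = Trainer(gpus=2, accelerator="dp", max_epochs=1)
    trainer.fit(DeviceEchoModel(), data)
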
pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 28 additions & 5 deletions
@@ -102,6 +102,7 @@ def __init__(
         cpu_checkpointing: bool = False,
         contiguous_memory_optimization: bool = False,
         synchronize_checkpoint_boundary: bool = False,
+        save_full_weights: bool = True,
     ) -> None:
         """
@@ -177,11 +178,16 @@ def __init__(
                 Not supported by all models
 
             synchronize_checkpoint_boundary: Insert :func:`torch.cuda.synchronize` at each checkpoint boundary.
+
+            save_full_weights: Gathers weights across all processes before saving to disk
+                when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
+                rather than individual sharded weight files.
+                Disable to save sharded states individually. (Default: True)
         """
         if not _DEEPSPEED_AVAILABLE:
             raise MisconfigurationException(
                 "To use the DeepSpeed plugin, you must have DeepSpeed installed."
-                " pip install deepspeed mpi4py"
+                " pip install deepspeed"
             )
         super().__init__(
             parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment
@@ -205,11 +211,13 @@ def __init__(
             allgather_partitions=allgather_partitions,
             reduce_scatter=reduce_scatter,
             allgather_bucket_size=allgather_bucket_size,
-            reduce_bucket_size=reduce_bucket_size
+            reduce_bucket_size=reduce_bucket_size,
         )
         self._config_initialized = False
         deepspeed.utils.logging.logger.setLevel(logging_level)
 
+        self.save_full_weights = save_full_weights
+
         # default FP16 parameters.
         self.loss_scale = loss_scale
         self.initial_scale_power = initial_scale_power
@@ -472,17 +480,27 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 
         Args:
+            checkpoint: The checkpoint state dictionary
             filepath: write-target file's path
-            weights_only: saving model weights only
         """
         if self.world_size > 1 and self.zero_stage_3:
+            if self.save_full_weights:
+                # todo: expose this as general function in deepspeed
+                state_dict = self.deepspeed_engine._zero3_consolidated_fp16_state_dict()
+                if self.is_global_zero:
+                    # State dict keys will include reference to wrapper LightningDeepSpeedModule
+                    # Delete `module` prefix before saving.
+                    state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
+                    checkpoint['state_dict'] = state_dict
+                    return super().save_checkpoint(checkpoint, filepath)
+                return
+
             # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
             # dump states as a checkpoint dictionary object
             save_dir = self._filepath_to_dir(filepath)
             _exclude_keys = ['state_dict', 'optimizer_states', 'lr_schedulers']
             checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
             self.deepspeed_engine.save_checkpoint(save_dir, client_state=checkpoint)
-
         else:
             super().save_checkpoint(checkpoint, filepath)
@@ -491,7 +509,8 @@ def restore_model_state_from_ckpt_path(
         ckpt_path: str,
         map_location: Callable = lambda storage, loc: storage,
     ) -> Tuple[Dict, bool]:
-        if self.world_size > 1:
+        if not self.save_full_weights and self.world_size > 1:
+            # Rely on deepspeed to load the checkpoint and necessary information
            from pytorch_lightning.trainer.states import TrainerState
            stage_is_fit = self.lightning_module.trainer.state == TrainerState.FITTING
            save_dir = self._filepath_to_dir(ckpt_path)
@@ -511,6 +530,10 @@ def restore_model_state_from_ckpt_path(
             # hook: give user access to checkpoint if needed.
             self.lightning_module.on_load_checkpoint(client_state)
             return client_state, False
+
+        # Broadcast to ensure we load from the rank 0 checkpoint
+        # This doesn't have to be the case when using deepspeed sharded checkpointing
+        ckpt_path = self.broadcast(ckpt_path)
         return super().restore_model_state_from_ckpt_path(ckpt_path, map_location=map_location)
 
     def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:

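For context, the new `save_full_weights` flag is set at plugin construction time: with the default `save_full_weights=True`, rank 0 writes a single consolidated checkpoint for ZeRO Stage 3, while `save_full_weights=False` keeps DeepSpeed's sharded per-process checkpoints and lets `restore_model_state_from_ckpt_path` defer loading to DeepSpeed. A rough usage sketch follows; the `gpus`, `precision`, and `stage=3` arguments are illustrative assumptions about a ZeRO Stage 3 setup and are not part of this diff:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Sketch only: the exact Trainer/plugin arguments depend on your setup.
trainer = Trainer(
    gpus=2,
    precision=16,
    plugins=DeepSpeedPlugin(
        stage=3,
        save_full_weights=False,  # keep sharded checkpoints instead of one consolidated file
    ),
)
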
pytorch_lightning/utilities/imports.py

Lines changed: 2 additions & 1 deletion
@@ -69,6 +69,7 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
 _TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
+_TORCH_GREATER_EQUAL_1_8_1 = _compare_version("torch", operator.ge, "1.8.1")
 _TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")
 
 _APEX_AVAILABLE = _module_available("apex.amp")
@@ -80,7 +81,7 @@ def _compare_version(package: str, op, version) -> bool:
 _HOROVOD_AVAILABLE = _module_available("horovod.torch")
 _HYDRA_AVAILABLE = _module_available("hydra")
 _HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
-_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False
+_KINETO_AVAILABLE = _TORCH_GREATER_EQUAL_1_8_1 and torch.profiler.kineto_available()
 _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')

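In other words, `_KINETO_AVAILABLE` now short-circuits on the torch version first, so `torch.profiler.kineto_available()` is only consulted on torch >= 1.8.1. A small guard in the same spirit (the profiler calls below are an illustrative assumption, not taken from this commit):

import torch
from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE

if _KINETO_AVAILABLE:
    # only reached on torch >= 1.8.1 with kineto support compiled in
    from torch.profiler import ProfilerActivity, profile

    with profile(activities=[ProfilerActivity.CPU]) as prof:
        torch.ones(8).sum()
    print(prof.key_averages().table())
else:
    print("kineto profiler is not available on this torch build")
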
requirements/extra.txt

Lines changed: 0 additions & 1 deletion
@@ -10,4 +10,3 @@ hydra-core>=1.0
 # todo: when switch to standard package stream, drop `fairscale` from hard mocked docs libs
 https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip
 jsonargparse[signatures]>=3.3.1
-deepspeed>=0.3.13

tests/models/test_horovod.py

Lines changed: 2 additions & 0 deletions
@@ -376,6 +376,8 @@ def _compute_batch():
     horovod.run(_compute_batch, np=2)
 
 
+# todo: need to be fixed :]
+@pytest.mark.skip(reason="TODO Breaking CI: Aborted (core dumped)")
 @RunIf(skip_windows=True, horovod=True)
 def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):

tests/overrides/test_data_parallel.py

Lines changed: 68 additions & 0 deletions
@@ -2,8 +2,11 @@
 
 import pytest
 import torch
+import torch.nn as nn
 from torch.nn import DataParallel
 
+from pytorch_lightning import LightningModule
+from pytorch_lightning.core.decorators import auto_move_data
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.data_parallel import (
     LightningParallelModule,
@@ -123,3 +126,68 @@ def training_step(self, batch, batch_idx):
     wrapped_model = LightningParallelModule(model)
     output = wrapped_model(batch, batch_idx)
     assert output["python scalar"] == torch.tensor([12.3], device=device)
+
+
+@RunIf(min_gpus=2)
+@pytest.mark.parametrize(
+    "nest, unnest", [
+        (lambda x: x, lambda x: x),
+        (lambda x: dict(data=x), lambda x: x["data"]),
+        (lambda x: [x, (x, x)], lambda x: x[1][0]),
+    ]
+)
+def test_lightning_parallel_module_device_access(nest, unnest):
+    """ Test that self.device returns the correct value in replicas of DataParallel. """
+
+    class DeviceAccessModel(LightningModule):
+
+        def __init__(self):
+            super().__init__()
+            self.layer = nn.Linear(2, 3)
+
+        @auto_move_data
+        def training_step(self, batch, batch_idx):
+            batch = unnest(batch)
+            assert batch.shape == torch.Size([1, 1])
+            assert self.device.index == batch.item()
+            assert self.device == self.layer.weight.device
+            return torch.tensor(1, device=self.device)
+
+    pl_module = DeviceAccessModel()
+    # required for redirecting the forward call to training_step
+    pl_module.trainer = Mock()
+    pl_module.trainer._running_stage = RunningStage.TRAINING
+
+    root_device = torch.device("cuda", 0)
+    wrapped_module = LightningParallelModule(pl_module).to(root_device)
+    model = DataParallel(wrapped_module, device_ids=[0, 1])
+
+    data = torch.tensor([0.0, 1.0], device=root_device).view(2, 1)  # one value per gpu
+    data = data.to(root_device)
+    data = nest(data)
+    output = model(data, 0)
+    assert output.device == root_device
+    assert pl_module.device == root_device
+    assert torch.all(output.cpu().eq(torch.tensor([1, 1])))
+
+
+@RunIf(min_gpus=2)
+def test_lightning_parallel_module_device_access_warning():
+    """ Test that we show a warning when the device can't be inferred from the input. """
+
+    class DeviceAccessModel(LightningModule):
+
+        def training_step(self, batch, batch_idx):
+            pass
+
+    pl_module = DeviceAccessModel()
+    # required for redirecting the forward call to training_step
+    pl_module.trainer = Mock()
+    pl_module.trainer._running_stage = RunningStage.TRAINING
+
+    wrapped_module = LightningParallelModule(pl_module).cuda()
+    model = DataParallel(wrapped_module, device_ids=[0, 1])
+
+    data = dict(x=1)  # contains no tensors
+    with pytest.warns(UserWarning, match="Could not determine on which device the inputs are."):
+        _ = model(data, 0)
