Commit e806bb7

awaelchli and tchaton authored
Refactor LightningDistributedDataParallel (#5185)
* add wrapper
* add squeeze
* replace LightningDistributedDP
* update import
* module access
* inputs
* refactor warning
* update
* resolve flake8
* remove old class
* set find unused params to False
* update docstrings
* update docs
* update docs
* add changelog
* deprecation
* rename wrapper -> module
* rename pl_module
* add unit tests
* Revert "add changelog". This reverts commit 02ec0a6.
* Revert "set find unused params to False". This reverts commit 8e45151.

Co-authored-by: Ubuntu <[email protected]>
1 parent 6130813 commit e806bb7
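In short, the commit retires the custom `LightningDistributedDataParallel` wrapper in favour of a thin `LightningDistributedModule` that is handed to the stock `torch.nn.parallel.DistributedDataParallel`. A minimal sketch of the new wrapping pattern, following the docstring example added below; `lightning_module` and `local_rank` are placeholders assumed to exist in the calling code:

    from torch.nn.parallel import DistributedDataParallel

    from pytorch_lightning.overrides.data_parallel import LightningDistributedModule

    # placeholders, assumed to be set up by the surrounding training script
    lightning_module = ...  # a pytorch_lightning.LightningModule
    local_rank = 0          # rank of this process on the node

    # the wrapper redirects forward() to training_step / validation_step / test_step,
    # so plain torch DDP can be used instead of the deprecated subclass
    ddp_model = DistributedDataParallel(
        module=LightningDistributedModule(lightning_module),
        device_ids=[local_rank],
    )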

File tree

12 files changed: +195 -105 lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -77,7 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446))


-- Changed the default of `find_unused_parameters` to `False` in DDP ([#5435](https://github.com/PyTorchLightning/pytorch-lightning/pull/5435))
+- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))
+

 ### Deprecated
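The entry above only swaps the PR reference, but the practical point is that `find_unused_parameters` is now an ordinary keyword forwarded by the DDP plugin to `DistributedDataParallel` rather than behaviour baked into the Lightning wrapper. A hedged sketch of pinning the flag explicitly (here to `False`), assuming the `Trainer(gpus=..., accelerator=..., plugins=[...])` interface of this release:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    # DDPPlugin forwards its keyword arguments to DistributedDataParallel,
    # so the flag can be set explicitly regardless of the framework default
    trainer = Trainer(
        gpus=2,
        accelerator="ddp",
        plugins=[DDPPlugin(find_unused_parameters=False)],
    )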

pytorch_lightning/accelerators/accelerator.py

Lines changed: 0 additions & 4 deletions
@@ -104,10 +104,6 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
-
-        if not automatic_optimization and self.ddp_plugin is not None:
-            # Manually prepare for reduce as user calling backwards manually
-            self.ddp_plugin.on_after_manual_backward(self.trainer.model)
         return closure_loss

     def clip_gradients(self, optimizer, clip_val=None):

pytorch_lightning/core/hooks.py

Lines changed: 1 addition & 2 deletions
@@ -564,8 +564,7 @@ def transfer_batch_to_device(self, batch, device)
         Note:
             This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
             for your custom batch objects, you need to define your custom
-            :class:`~torch.nn.parallel.DistributedDataParallel` or
-            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
+            :class:`~torch.nn.parallel.DistributedDataParallel` and
             override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

         See Also:

pytorch_lightning/overrides/data_parallel.py

Lines changed: 66 additions & 62 deletions
@@ -14,9 +14,10 @@

 import itertools
 import threading
+import warnings
 from collections.abc import Iterable, Mapping
 from itertools import chain
-from typing import Optional
+from typing import Any, Optional

 import torch
 from torch import Tensor
@@ -25,6 +26,7 @@
 from torch.nn.parallel import DistributedDataParallel
 from torch.nn.parallel._functions import Gather

+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -150,73 +152,75 @@ def parallel_apply(self, replicas, inputs, kwargs):


 class LightningDistributedDataParallel(DistributedDataParallel):
-    """
-    Override the forward call in lightning so it goes to training and validation step respectively
-    """
-    PREPARE_FOR_BACKWARDS = True

-    def parallel_apply(self, replicas, inputs, kwargs):
-        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+    def __init__(self, module: LightningModule, *args, **kwargs):
+        warnings.warn(
+            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
+            " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
+            DeprecationWarning
+        )
+        super().__init__(LightningDistributedModule(module), *args, **kwargs)

-    def forward(self, *inputs, **kwargs):  # pragma: no-cover
-        self._sync_params()
-        self.reducer_reset_hooks()
-        fx_called: str = ''
-
-        if self.device_ids:
-
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
-            if len(self.device_ids) == 1:
-                # --------------
-                # LIGHTNING MOD
-                # --------------
-                # normal
-                # output = self.module(*inputs[0], **kwargs[0])
-                # lightning
-                if self.module.training:
-                    output = self.module.training_step(*inputs[0], **kwargs[0])
-                    fx_called = 'training_step'
-                elif self.module.testing:
-                    output = self.module.test_step(*inputs[0], **kwargs[0])
-                    fx_called = 'test_step'
-                else:
-                    output = self.module.validation_step(*inputs[0], **kwargs[0])
-                    fx_called = 'validation_step'
-            else:
-                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
-                output = self.gather(outputs, self.output_device)
-        else:
-            # output = self.module(*inputs, **kwargs)
-            # normal lightning (ddp_cpu)
-            if self.module.training:
-                output = self.module.training_step(*inputs, **kwargs)
-            elif self.module.testing:
-                output = self.module.test_step(*inputs, **kwargs)
-            else:
-                output = self.module.validation_step(*inputs, **kwargs)

-        if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
-            self.reducer_prepare_for_backwards(output)
+class LightningDistributedModule(torch.nn.Module):
+
+    def __init__(self, pl_module: LightningModule):
+        """
+        Wraps the user's LightningModule and redirects the forward call to the appropriate
+        method, either ``training_step``, ``validation_step`` or ``test_step``.
+        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
+        shown in the example.

-        if output is None:
-            warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')
+        Example:
+
+            ddp_model = DistributedDataParallel(
+                module=LightningDistributedModule(lightning_module),
+                device_ids=[local_rank],
+                ...
+            )
+
+        Args:
+            pl_module: the model to wrap
+
+        """
+        super().__init__()
+        self.module = pl_module
+
+    def forward(self, *inputs, **kwargs):
+        if self.module.training:
+            output = self.module.training_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "training_step")
+        elif self.module.testing:
+            output = self.module.test_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "test_step")
+        else:
+            output = self.module.validation_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "validation_step")
         return output

-    def reducer_prepare_for_backwards(self, output):
-        self._reducer_prepared_for_backwards = True
-        if torch.is_grad_enabled():
-            # We'll return the output object verbatim since it is a freeform
-            # object. We need to find any tensors in this object, though,
-            # because we need to figure out which parameters were used during
-            # this forward pass, to ensure we short circuit reduction for any
-            # unused parameters. Only if `find_unused_parameters` is set.
-            if self.find_unused_parameters:
-                self.reducer.prepare_for_backward(list(_find_tensors(output)))
-            else:
-                self.reducer.prepare_for_backward([])
-
-    def reducer_reset_hooks(self):
-        self._reducer_prepared_for_backwards = False
+
+# In manual_optimization, we need to call reducer prepare_for_backward.
+# Note: Keep track of PyTorch DDP and update if there is a change
+# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
+def prepare_for_backward(model: DistributedDataParallel, output: Any):
+    if torch.is_grad_enabled() and model.require_backward_grad_sync:
+        model.require_forward_param_sync = True
+        # We'll return the output object verbatim since it is a freeform
+        # object. We need to find any tensors in this object, though,
+        # because we need to figure out which parameters were used during
+        # this forward pass, to ensure we short circuit reduction for any
+        # unused parameters. Only if `find_unused_parameters` is set.
+        if model.find_unused_parameters:
+            model.reducer.prepare_for_backward(list(_find_tensors(output)))
+        else:
+            model.reducer.prepare_for_backward([])
+    else:
+        model.require_forward_param_sync = False
+
+
+def warn_if_output_is_none(output: Any, method_name: str) -> None:
+    if output is None:
+        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')


 def warn_missing_output(fx_called):
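The new module-level `prepare_for_backward` helper replaces the old `reducer_prepare_for_backwards` method and is what `DDPPlugin.on_before_manual_backward` now calls (see the plugin diff below). A rough sketch of the manual-optimization flow after this refactor; the initialized process group, `lightning_module`, `batch` and `batch_idx` are assumed to exist, and `training_step` is assumed to return a plain loss tensor:

    from torch.nn.parallel import DistributedDataParallel

    from pytorch_lightning.overrides.data_parallel import (
        LightningDistributedModule,
        prepare_for_backward,
    )

    ddp_model = DistributedDataParallel(LightningDistributedModule(lightning_module))

    loss = ddp_model(batch, batch_idx)     # routed to lightning_module.training_step
    prepare_for_backward(ddp_model, loss)  # register used parameters with the reducer
    loss.backward()                        # gradients are reduced across processes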

pytorch_lightning/plugins/ddp_plugin.py

Lines changed: 30 additions & 23 deletions
@@ -16,11 +16,12 @@
 from typing import Any, Dict, List, Union

 import torch.distributed as torch_distrib
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer

 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
+from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward
 from pytorch_lightning.plugins.plugin import LightningPlugin
 from pytorch_lightning.utilities import DeviceType

@@ -29,15 +30,14 @@ class DDPPlugin(LightningPlugin):
     """
     Plugin to link a custom ddp implementation to any arbitrary accelerator.

-    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
-    which in turn forwards all args to `DistributedDataParallel`.
+    This plugin forwards all constructor arguments to :class:`~torch.nn.parallel.DistributedDataParallel`.

     Example::

         class MyDDP(DDPPlugin):

             def configure_ddp(self, model, device_ids):
-                model = MyDDPWrapper(model, device_ids)
+                model = MyDDPWrapper(LightningDistributedModule(model), device_ids)
                 return model

         my_ddp = MyDDP()
@@ -49,32 +49,40 @@ def __init__(self, **kwargs):

     def configure_ddp(
             self, model: LightningModule, device_ids: List[int]
-    ) -> LightningDistributedDataParallel:
+    ) -> DistributedDataParallel:
         """
-        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
+        Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
         Override to define a custom DDP implementation.

-        .. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel
-
+        .. note:: This requires that your DDP implementation subclasses
+            :class:`~torch.nn.parallel.DistributedDataParallel` and that
+            the original LightningModule gets wrapped by
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.

         The default implementation is::

             def configure_ddp(self, model, device_ids):
-                model = LightningDistributedDataParallel(
-                    model, device_ids=device_ids, **self._ddp_kwargs
+                model = DistributedDataParallel(
+                    LightningDistributedModule(model),
+                    device_ids=device_ids,
+                    **self._ddp_kwargs,
                 )
                 return model

         Args:
-            model: the lightningModule
+            model: the LightningModule
             device_ids: the list of devices available

         Returns:
-            the model wrapped in LightningDistributedDataParallel
+            the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`

         """
-        model = LightningDistributedDataParallel(
-            model,
+        # if unset, default `find_unused_parameters` to `True`
+        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
+            "find_unused_parameters", True
+        )
+        model = DistributedDataParallel(
+            module=LightningDistributedModule(model),
             device_ids=device_ids,
             **self._ddp_kwargs,
         )
@@ -131,7 +139,7 @@ def on_after_setup_optimizers(self, trainer):

     def get_model_from_plugin(
             self,
-            model: Union[LightningDistributedDataParallel, LightningModule]
+            model: Union[DistributedDataParallel, LightningModule]
     ) -> LightningModule:
         """
         Override to modify returning base :class:`LightningModule`
@@ -147,24 +155,23 @@ def get_model_from_plugin(
         Returns: Reference :class:`LightningModule` within parallel wrapper.

         """
-        if isinstance(model, LightningDistributedDataParallel):
-            return model.module
+        if isinstance(model, DistributedDataParallel):
+            model = model.module
+        if isinstance(model, LightningDistributedModule):
+            model = model.module
         return model

     @contextmanager
-    def block_backward_sync(self, model: LightningDistributedDataParallel):
+    def block_backward_sync(self, model: DistributedDataParallel):
         """
         Blocks ddp sync gradients behaviour on backwards pass.
         This is useful for skipping sync when accumulating gradients, reducing communication overhead
         Returns: context manager with sync behaviour off
         """
         yield model.no_sync()

-    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
-        model.reducer_prepare_for_backwards(output)
-
-    def on_after_manual_backward(self, model: LightningDistributedDataParallel):
-        model.reducer_reset_hooks()
+    def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
+        prepare_for_backward(model, output)

     def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
         return distributed_sampler_kwargs
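For custom DDP implementations, the contract described in the updated docstring is: wrap the LightningModule in `LightningDistributedModule` and hand it to `torch.nn.parallel.DistributedDataParallel` (or a subclass of it). A sketch of a custom plugin following that default implementation; the class name is illustrative only:

    from torch.nn.parallel import DistributedDataParallel

    from pytorch_lightning.overrides.data_parallel import LightningDistributedModule
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


    class MyDDP(DDPPlugin):

        def configure_ddp(self, model, device_ids):
            # mirrors the default implementation from the docstring above
            return DistributedDataParallel(
                LightningDistributedModule(model),
                device_ids=device_ids,
                **self._ddp_kwargs,
            )

    # usage sketch: trainer = Trainer(..., plugins=[MyDDP()])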

pytorch_lightning/plugins/ddp_sequential_plugin.py

Lines changed: 4 additions & 5 deletions
@@ -21,7 +21,6 @@

 from pytorch_lightning import LightningModule
 from pytorch_lightning import _logger as log
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -137,7 +136,7 @@ def init_ddp_connection(
         self._infer_model_balance(trainer)
         self._assert_valid_model_balance(trainer)

-    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
+    def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
         pass

     def _infer_model_balance(self, trainer):
@@ -267,10 +266,10 @@ def _check_arguments(self, trainer):
     def configure_ddp(
             self,
             model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
-        ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
+        model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
         # Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel
-        ddp_plugin.PREPARE_FOR_BACKWARDS = False
-        return ddp_plugin
+        model.require_backward_grad_sync = False
+        return model

     @rank_zero_only
     def rpc_save_model(
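With the `PREPARE_FOR_BACKWARDS` class attribute gone, the sequential plugin opts out of backward preparation through DDP's own `require_backward_grad_sync` flag, which the new `prepare_for_backward` helper checks before touching the reducer. Reusing `ddp_model` and `loss` from the manual-optimization sketch above, the effect looks roughly like this:

    # with the sync flag off, prepare_for_backward skips the reducer entirely
    # and only resets require_forward_param_sync
    ddp_model.require_backward_grad_sync = False
    prepare_for_backward(ddp_model, loss)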

pytorch_lightning/plugins/sharded_plugin.py

Lines changed: 1 addition & 4 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Union, Any
+from typing import Any, List, Optional, Union

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
@@ -97,6 +97,3 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list:

     def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any):
         pass
-
-    def on_after_manual_backward(self, model: 'LightningShardedDataParallel'):
-        pass

pytorch_lightning/trainer/properties.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def progress_bar_callback(self):
     @property
     def progress_bar_dict(self) -> dict:
         """ Read-only for progress bar metrics. """
-        ref_model = self.model if not self.data_parallel else self.model.module
+        ref_model = self.get_model()
         ref_model = cast(LightningModule, ref_model)
         return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 3 deletions
@@ -137,9 +137,7 @@ def setup_training(self, model: LightningModule):
         # --------------------------
         # Setup??
         # --------------------------
-        ref_model = model
-        if self.trainer.data_parallel:
-            ref_model = model.module
+        ref_model = self.trainer.get_model()

         # set the ranks and devices
         self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank
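Both trainer-side changes route through `get_model()`, which delegates to the plugin's `get_model_from_plugin` and now has to peel off two layers instead of one. A sketch of that unwrap, mirroring the ddp_plugin diff above; the standalone function name is illustrative only:

    from torch.nn.parallel import DistributedDataParallel

    from pytorch_lightning.overrides.data_parallel import LightningDistributedModule


    def unwrap_lightning_module(model):
        # DistributedDataParallel -> LightningDistributedModule -> LightningModule
        if isinstance(model, DistributedDataParallel):
            model = model.module
        if isinstance(model, LightningDistributedModule):
            model = model.module
        return model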
