|
14 | 14 |
|
15 | 15 | import itertools |
16 | 16 | import threading |
| 17 | +import warnings |
17 | 18 | from collections.abc import Iterable, Mapping |
18 | 19 | from itertools import chain |
19 | | -from typing import Optional |
| 20 | +from typing import Any, Optional |
20 | 21 |
|
21 | 22 | import torch |
22 | 23 | from torch import Tensor |
|
25 | 26 | from torch.nn.parallel import DistributedDataParallel |
26 | 27 | from torch.nn.parallel._functions import Gather |
27 | 28 |
|
| 29 | +from pytorch_lightning.core.lightning import LightningModule |
28 | 30 | from pytorch_lightning.core.step_result import Result |
29 | 31 | from pytorch_lightning.utilities.warnings import WarningCache |
30 | 32 |
|
@@ -150,73 +152,75 @@ def parallel_apply(self, replicas, inputs, kwargs): |
150 | 152 |
|
151 | 153 |
|
152 | 154 | class LightningDistributedDataParallel(DistributedDataParallel): |
153 | | - """ |
154 | | - Override the forward call in lightning so it goes to training and validation step respectively |
155 | | - """ |
156 | | - PREPARE_FOR_BACKWARDS = True |
157 | 155 |
|
158 | | - def parallel_apply(self, replicas, inputs, kwargs): |
159 | | - return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) |
| 156 | + def __init__(self, module: LightningModule, *args, **kwargs): |
| 157 | + warnings.warn( |
| 158 | + "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4." |
|    | 159 | +            " From now on we recommend subclassing `torch.nn.parallel.DistributedDataParallel` directly.", |
| 160 | + DeprecationWarning |
| 161 | + ) |
| 162 | + super().__init__(LightningDistributedModule(module), *args, **kwargs) |
160 | 163 |
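
A minimal migration sketch following the deprecation message above; `pl_module` and `local_rank` are placeholder names, not part of this diff:

    # before (deprecated since v1.2, removed in v1.4):
    ddp_model = LightningDistributedDataParallel(pl_module, device_ids=[local_rank])

    # after: wrap the LightningModule and hand it to torch's DDP directly
    from torch.nn.parallel import DistributedDataParallel
    ddp_model = DistributedDataParallel(
        module=LightningDistributedModule(pl_module),
        device_ids=[local_rank],
    )
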
|
161 | | - def forward(self, *inputs, **kwargs): # pragma: no-cover |
162 | | - self._sync_params() |
163 | | - self.reducer_reset_hooks() |
164 | | - fx_called: str = '' |
165 | | - |
166 | | - if self.device_ids: |
167 | | - |
168 | | - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) |
169 | | - if len(self.device_ids) == 1: |
170 | | - # -------------- |
171 | | - # LIGHTNING MOD |
172 | | - # -------------- |
173 | | - # normal |
174 | | - # output = self.module(*inputs[0], **kwargs[0]) |
175 | | - # lightning |
176 | | - if self.module.training: |
177 | | - output = self.module.training_step(*inputs[0], **kwargs[0]) |
178 | | - fx_called = 'training_step' |
179 | | - elif self.module.testing: |
180 | | - output = self.module.test_step(*inputs[0], **kwargs[0]) |
181 | | - fx_called = 'test_step' |
182 | | - else: |
183 | | - output = self.module.validation_step(*inputs[0], **kwargs[0]) |
184 | | - fx_called = 'validation_step' |
185 | | - else: |
186 | | - outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) |
187 | | - output = self.gather(outputs, self.output_device) |
188 | | - else: |
189 | | - # output = self.module(*inputs, **kwargs) |
190 | | - # normal lightning (ddp_cpu) |
191 | | - if self.module.training: |
192 | | - output = self.module.training_step(*inputs, **kwargs) |
193 | | - elif self.module.testing: |
194 | | - output = self.module.test_step(*inputs, **kwargs) |
195 | | - else: |
196 | | - output = self.module.validation_step(*inputs, **kwargs) |
197 | 164 |
|
198 | | - if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS: |
199 | | - self.reducer_prepare_for_backwards(output) |
| 165 | +class LightningDistributedModule(torch.nn.Module): |
| 166 | + |
| 167 | + def __init__(self, pl_module: LightningModule): |
| 168 | + """ |
| 169 | + Wraps the user's LightningModule and redirects the forward call to the appropriate |
|    | 170 | +        method, either ``training_step``, ``validation_step``, or ``test_step``. |
| 171 | + This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as |
| 172 | + shown in the example. |
200 | 173 |
|
201 | | - if output is None: |
202 | | - warn_missing_output(f'{fx_called} returned None. Did you forget to return an output') |
| 174 | + Example: |
| 175 | +
|
| 176 | + ddp_model = DistributedDataParallel( |
| 177 | + module=LightningDistributedModule(lightning_module), |
| 178 | + device_ids=[local_rank], |
| 179 | + ... |
| 180 | + ) |
| 181 | +
|
| 182 | + Args: |
| 183 | + pl_module: the model to wrap |
| 184 | +
|
| 185 | + """ |
| 186 | + super().__init__() |
| 187 | + self.module = pl_module |
| 188 | + |
| 189 | + def forward(self, *inputs, **kwargs): |
| 190 | + if self.module.training: |
| 191 | + output = self.module.training_step(*inputs, **kwargs) |
| 192 | + warn_if_output_is_none(output, "training_step") |
| 193 | + elif self.module.testing: |
| 194 | + output = self.module.test_step(*inputs, **kwargs) |
| 195 | + warn_if_output_is_none(output, "test_step") |
| 196 | + else: |
| 197 | + output = self.module.validation_step(*inputs, **kwargs) |
| 198 | + warn_if_output_is_none(output, "validation_step") |
203 | 199 | return output |
204 | 200 |
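
A short sketch of how the dispatch in `forward` behaves; `lit`, `batch`, and `batch_idx` are illustrative, and the `testing` flag is normally managed by the Trainer and is toggled here only to show all three branches:

    wrapped = LightningDistributedModule(lit)

    lit.train()                        # lit.training is True
    out = wrapped(batch, batch_idx)    # -> lit.training_step(batch, batch_idx)

    lit.eval()                         # lit.training is False
    lit.testing = True
    out = wrapped(batch, batch_idx)    # -> lit.test_step(batch, batch_idx)

    lit.testing = False
    out = wrapped(batch, batch_idx)    # -> lit.validation_step(batch, batch_idx)
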
|
205 | | - def reducer_prepare_for_backwards(self, output): |
206 | | - self._reducer_prepared_for_backwards = True |
207 | | - if torch.is_grad_enabled(): |
208 | | - # We'll return the output object verbatim since it is a freeform |
209 | | - # object. We need to find any tensors in this object, though, |
210 | | - # because we need to figure out which parameters were used during |
211 | | - # this forward pass, to ensure we short circuit reduction for any |
212 | | - # unused parameters. Only if `find_unused_parameters` is set. |
213 | | - if self.find_unused_parameters: |
214 | | - self.reducer.prepare_for_backward(list(_find_tensors(output))) |
215 | | - else: |
216 | | - self.reducer.prepare_for_backward([]) |
217 | | - |
218 | | - def reducer_reset_hooks(self): |
219 | | - self._reducer_prepared_for_backwards = False |
| 201 | + |
|    | 202 | +# In manual_optimization, we need to call the reducer's prepare_for_backward ourselves. |
|    | 203 | +# Note: keep in sync with PyTorch DDP and update if the upstream implementation changes: |
| 204 | +# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 |
| 205 | +def prepare_for_backward(model: DistributedDataParallel, output: Any): |
| 206 | + if torch.is_grad_enabled() and model.require_backward_grad_sync: |
| 207 | + model.require_forward_param_sync = True |
| 208 | + # We'll return the output object verbatim since it is a freeform |
| 209 | + # object. We need to find any tensors in this object, though, |
| 210 | + # because we need to figure out which parameters were used during |
| 211 | + # this forward pass, to ensure we short circuit reduction for any |
| 212 | + # unused parameters. Only if `find_unused_parameters` is set. |
| 213 | + if model.find_unused_parameters: |
| 214 | + model.reducer.prepare_for_backward(list(_find_tensors(output))) |
| 215 | + else: |
| 216 | + model.reducer.prepare_for_backward([]) |
| 217 | + else: |
| 218 | + model.require_forward_param_sync = False |
| 219 | + |
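
A hedged sketch of the manual-optimization call site this helper targets; `ddp_model`, `batch`, and `batch_idx` are placeholders and the loss extraction is only illustrative:

    output = ddp_model(batch, batch_idx)      # runs the wrapped *_step via forward above
    prepare_for_backward(ddp_model, output)   # lets the reducer mark which parameters were used
    loss = output["loss"] if isinstance(output, dict) else output
    loss.backward()
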
| 220 | + |
| 221 | +def warn_if_output_is_none(output: Any, method_name: str) -> None: |
| 222 | + if output is None: |
| 223 | + warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?') |
220 | 224 |
|
221 | 225 |
|
222 | 226 | def warn_missing_output(fx_called): |
|