
Commit e463248

tchaton authored and Borda committed
[bug-fix] Call transfer_batch_to_device in DDPlugin (#5195)
* hacking out
* update
* remove useless on_before_forward
* update
* remove overriden
* iremove os
* use on_before_forward
* resolve flake8
* add test
* update
* add single_process_per_device
* resolve flake8
* update
* resolve
* update
* update
* update
* add comment
* resolve bug with sharded
* update
* remove property
* update
* resolve test
* resolve bug
* update on comments
* update doc
* Update pytorch_lightning/core/hooks.py
  Co-authored-by: Rohit Gupta <[email protected]>
* update on comments
* Update pytorch_lightning/plugins/ddp_plugin.py
  Co-authored-by: Rohit Gupta <[email protected]>
* Update pytorch_lightning/plugins/ddp_plugin.py
  Co-authored-by: Rohit Gupta <[email protected]>
* resolve pep8
* add device_ids to pipe
* update on comments
* update
* resolve
* update
* update
* update

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
(cherry picked from commit d510707)
1 parent 2c44e5a commit e463248
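This fix targets DDP setups where each process drives a single device, which is the usual `ddp` / `ddp_spawn` configuration. A minimal sketch of a run that exercises the fixed path, assuming the 1.1-era Trainer flags, with `LitModel` standing in for any user-defined LightningModule:

import pytorch_lightning as pl

# `LitModel` is a placeholder for a LightningModule defined elsewhere.
model = LitModel()

# With accelerator="ddp" (or "ddp_spawn"), every process owns exactly one GPU,
# so the plugin sees len(device_ids) == 1 and on_before_forward now routes the
# inputs through LightningModule.transfer_batch_to_device.
trainer = pl.Trainer(gpus=2, accelerator="ddp")
trainer.fit(model)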

14 files changed, +74 -21 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))
+
+
 
 ## [1.1.3] - 2021-01-05
 

pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 1 addition & 0 deletions
@@ -210,6 +210,7 @@ def ddp_train(self, process_idx, mp_queue, model):
     def configure_ddp(
             self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
+        self.ddp_plugin.device_ids = device_ids
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
 
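The same one-line change is repeated in each of the accelerators in this commit: the per-process `device_ids` are recorded on the DDP plugin before the plugin wraps the model. A commented sketch of the resulting method (comments added here for illustration):

from typing import List

from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule


def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
    # Record the device ids owned by this process so the plugin can later tell
    # whether it runs a single process per device (len(device_ids) == 1).
    self.ddp_plugin.device_ids = device_ids
    # Delegate the actual DistributedDataParallel wrapping to the plugin.
    model = self.ddp_plugin.configure_ddp(model, device_ids)
    return model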

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 1 addition & 0 deletions
@@ -315,6 +315,7 @@ def ddp_train(self, process_idx, model):
     def configure_ddp(
             self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
+        self.ddp_plugin.device_ids = device_ids
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
 

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

Lines changed: 1 addition & 0 deletions
@@ -239,6 +239,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
     def configure_ddp(
             self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
+        self.ddp_plugin.device_ids = device_ids
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
 

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ def ddp_train(self, process_idx, model):
     def configure_ddp(
             self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
+        self.ddp_plugin.device_ids = device_ids
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
 

pytorch_lightning/accelerators/ddp_spawn_accelerator.py

Lines changed: 1 addition & 0 deletions
@@ -271,6 +271,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
     def configure_ddp(
             self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
+        self.ddp_plugin.device_ids = device_ids
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
 

pytorch_lightning/core/hooks.py

Lines changed: 3 additions & 3 deletions
@@ -562,9 +562,9 @@ def transfer_batch_to_device(self, batch, device)
         any other device than the one passed in as argument (unless you know what you are doing).
 
         Note:
-            This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
-            for your custom batch objects, you need to define your custom
-            :class:`~torch.nn.parallel.DistributedDataParallel` and
+            This hook only runs on single GPU training and DDP (no data-parallel).
+            If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
+            you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
             override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
 
         See Also:
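To make the updated note above concrete, here is a hedged sketch of a module that relies on this hook for a custom batch type; `CustomBatch` and `LitClassifier` are hypothetical names used only to illustrate the override point:

import torch
import pytorch_lightning as pl


class CustomBatch:
    # Hypothetical container that torch cannot move to a device automatically.
    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
        self.inputs = inputs
        self.targets = targets


class LitClassifier(pl.LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # Move the wrapped tensors ourselves; with this fix the hook is
            # also called for DDP when each process owns a single device.
            batch.inputs = batch.inputs.to(device)
            batch.targets = batch.targets.to(device)
            return batch
        # Defer to Lightning's default handling for standard batch types.
        return super().transfer_batch_to_device(batch, device)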

pytorch_lightning/plugins/ddp_plugin.py

Lines changed: 9 additions & 8 deletions
@@ -110,22 +110,23 @@ def init_ddp_connection(
             torch_backend, rank=global_rank, world_size=world_size
         )
 
+    @property
+    def is_running_single_process_per_device(self) -> bool:
+        # objects do not need to be scattered in single process per device, move objects upfront to device
+        # This property is used in ``self.on_before_forward`` function.
+        return self.device_ids is not None and len(self.device_ids) == 1
+
     def on_before_forward(self, model: LightningModule, *args):
         """
-        Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
-        within the DDP wrapper.
-
-        Example::
-
-            def on_before_forward(self, model, *args):
-                batch, batch_idx = args
-                return batch.to(model.device)
+        Override to handle custom edge case.
 
         Args:
             args: Inputs to the model.
             model: Model to train.
         Returns: args moved to correct device if needed.
         """
+        if self.is_running_single_process_per_device:
+            args = model.transfer_batch_to_device(args, model.device)
         return args
 
     def optimizer_state(self, optimizer: Optimizer) -> dict:
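The docstring now simply says "Override to handle custom edge case.", and the removed example hinted at what such an override can look like. A hedged sketch of a custom plugin in that spirit, assuming the forward inputs are a `(batch, batch_idx)` tuple as in the old example:

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class MyDDPPlugin(DDPPlugin):
    # Hypothetical subclass illustrating a custom on_before_forward override.
    def on_before_forward(self, model, *args):
        batch, batch_idx = args
        # Move only the batch to the model's device; batch_idx is a plain int.
        return batch.to(model.device), batch_idx

Such a plugin would typically be handed to the Trainer through its `plugins` argument.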

pytorch_lightning/plugins/ddp_sequential_plugin.py

Lines changed: 1 addition & 1 deletion
@@ -19,8 +19,8 @@
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel
 
-from pytorch_lightning import LightningModule
 from pytorch_lightning import _logger as log
+from pytorch_lightning import LightningModule
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

pytorch_lightning/plugins/sharded_plugin.py

Lines changed: 0 additions & 3 deletions
@@ -42,9 +42,6 @@ def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
         optimizer.consolidate_state_dict()
         return self._optim_state_dict(optimizer)
 
-    def on_before_forward(self, model: LightningModule, *args):
-        return model.transfer_batch_to_device(args, model.trainer.root_gpu)
-
     def _check_fairscale(self):
         if not _FAIRSCALE_AVAILABLE:
             raise MisconfigurationException(
