
Commit 4aaca17

awaelchli and tchaton authored

Update setup logic in training type plugins (data-parallel) [3 / n] (#10010)

Co-authored-by: thomas chaton <[email protected]>
1 parent 854bdc0 commit 4aaca17

File tree

2 files changed: +7 -2 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
 * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
+* Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))

 ### Changed

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 6 additions & 2 deletions
@@ -14,7 +14,7 @@
 from typing import List, Optional

 import torch
-from torch.nn import DataParallel
+from torch.nn import DataParallel, Module

 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -54,7 +54,11 @@ def world_size(self) -> int:
     def setup(self) -> None:
         # model needs to be moved to the device before it is wrapped
         self.model_to_device()
-        self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices)
+        self._model = self._setup_model(LightningParallelModule(self._model))
+
+    def _setup_model(self, model: Module) -> DataParallel:
+        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
+        return DataParallel(module=model, device_ids=self.parallel_devices)

     def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
         """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.
