
Commit 4e9b453

SeanNaren and awaelchli authored
[Fix] Move init dist connection into the setup function (#6506)
* Move connection setup into the setup function. Call setup hook after we set up the accelerator
* Added CHANGELOG.md
* fix setup order in callback test
* fix input arguments in test
* Mock distributed function, remove protection to turn into training type hook
* Remove import
* Add missing mock, ensure custom plugin does not create children process
* Skip test on windows
* Update deepspeed to init connection in setup
* Do not initialize distributed module
* Move DeepSpeed tests to special tests since dist communication is being set up
* Special the test to see if this fixes CI
* Delete accelerator connector test to see if its causing build to fail
* Delete deepspeed test
* Revert "Delete accelerator connector test to see if its causing build to fail" This reverts commit edde60b
* Revert "Delete deepspeed test" This reverts commit 9d317429
* Reverse hook
* Reverse setup hooks to debug again
* Add todo so i know where i left off
* For single device move in pre_dispatch after setup function
* Add additional model to device hook if any additional parameters have been set
* See if we can enable deepspeed tests
* Revert "See if we can enable deepspeed tests" This reverts commit b5450de
* See if this hook approach works
* Introduce new granular hooks
* Remove import, fix tpu spawn by moving the function to setup
* Added missing special test

Co-authored-by: Adrian Wälchli <[email protected]>
1 parent b606171 commit 4e9b453
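At a high level, this commit splits accelerator startup into more granular hooks: `connect` transfers model ownership to the training type plugin, `setup_environment` spawns processes and initializes the distributed connection, and `setup` wires up the training type plugin, optimizers, and precision plugin, with the user-facing `setup` hook now running after the environment is ready. The snippet below is only a rough sketch of that intended call order; the stub classes are illustrative placeholders, not the real Trainer or Accelerator implementation.

# Illustrative sketch of the hook order introduced by this commit.
# The stub classes below are placeholders, not PyTorch Lightning internals.

class StubAccelerator:
    def connect(self, model):
        print("connect: hand the model to the training type plugin")

    def setup_environment(self):
        print("setup_environment: spawn children / init torch.distributed")

    def setup(self, trainer, model):
        print("setup: training type plugin, optimizers, precision plugin")


class StubTrainer:
    def __init__(self):
        self.accelerator = StubAccelerator()

    def call_setup_hook(self, model):
        print("user `setup` hook: distributed environment is already available here")

    def fit(self, model):
        self.accelerator.connect(model)
        self.accelerator.setup_environment()
        self.call_setup_hook(model)          # now runs after the environment is set up
        self.accelerator.setup(self, model)


StubTrainer().fit(model=None)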

File tree: 16 files changed, +139 -100 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
+- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
+
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed
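The CHANGELOG entry above describes the user-visible effect: under DDP, `torch.distributed` is now initialized before the LightningModule/DataModule `setup` hook runs. A minimal sketch of a hook that relies on this, assuming a DDP run; the module itself is hypothetical:

import torch.distributed as dist
from pytorch_lightning import LightningModule


class MyModule(LightningModule):
    def setup(self, stage=None):
        # With this fix, the DDP process group is already initialized here,
        # so rank and world-size queries are safe inside the setup hook.
        if dist.is_available() and dist.is_initialized():
            print(f"setup on rank {dist.get_rank()} of {dist.get_world_size()}")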

pytorch_lightning/accelerators/accelerator.py

Lines changed: 21 additions & 13 deletions
@@ -65,17 +65,28 @@ def __init__(
         self.lr_schedulers: Sequence = []
         self.optimizer_frequencies: Sequence = []
 
-    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+    def connect(self, model: LightningModule) -> None:
+        """Transfers ownership of the model to this plugin"""
+        self.training_type_plugin.connect(model)
+
+    def setup_environment(self) -> None:
         """
-        Connects the plugins to the training process, creates optimizers
+        Setup any processes or distributed connections.
+        This is called before the LightningModule/DataModule setup hook
+        which allows the user to access the accelerator environment before setup is complete.
+        """
+        self.training_type_plugin.setup_environment()
 
+    def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
+        """
+        Setup plugins for the trainer fit and creates optimizers.
         Args:
-            trainer: the trainer instance to connect to
-            model: the model to train
+            trainer: the trainer instance
+            model: the LightningModule
         """
-        self.connect_training_type_plugin(self.training_type_plugin, model)
+        self.setup_training_type_plugin(self.training_type_plugin, model)
         self.setup_optimizers(trainer)
-        self.connect_precision_plugin(self.precision_plugin)
+        self.setup_precision_plugin(self.precision_plugin)
 
     def start_training(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_training(trainer)
@@ -332,14 +343,11 @@ def setup_optimizers(self, trainer: 'Trainer') -> None:
         self.lr_schedulers = lr_schedulers
         self.optimizer_frequencies = optimizer_frequencies
 
-    def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
-        """Attaches the training type plugin to the accelerator.
-        Also transfers ownership of the model to this plugin
-
-        """
-        plugin.connect(model)
+    def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
+        """Attaches the training type plugin to the accelerator."""
+        plugin.setup(model)
 
-    def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
+    def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None:
         """Attaches the precision plugin to the accelerator"""
         model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
         self.model = model

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 31 additions & 34 deletions
@@ -80,16 +80,16 @@ def distributed_sampler_kwargs(self):
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
-    def setup(self, model):
-        self._model = model
-
+    def setup_environment(self):
         # start the other scripts
         if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
             self._call_children_scripts()
 
         # set the task idx
         self.task_idx = self.cluster_environment.local_rank()
 
+        self.setup_distributed()
+
     def _call_children_scripts(self):
 
         # bookkeeping of spawned processes
@@ -161,6 +161,34 @@ def _call_children_scripts(self):
         delay = np.random.uniform(1, 5, 1)[0]
         sleep(delay)
 
+    def setup_distributed(self):
+        # TODO: check if needed
+        seed = os.environ.get("PL_GLOBAL_SEED")
+        if seed is not None:
+            seed_everything(int(seed))
+
+        # determine which process we are and world size
+        self.set_world_ranks()
+
+        # set warning rank
+        rank_zero_only.rank = self.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        self.init_ddp_connection(self.global_rank, self.world_size)
+
+        # on world_size=0 let everyone know training is starting
+        if self.is_global_zero and not torch.distributed.is_initialized():
+            log.info("-" * 100)
+            log.info(f"distributed_backend={self.distributed_backend}")
+            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
+            log.info("-" * 100)
+
+        # set the ranks and devices
+        self.dist.rank = self.global_rank
+        self.dist.device = self.root_device
+
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(
@@ -213,37 +241,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
         torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
 
     def pre_dispatch(self):
-        # TODO: check if needed
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
-        # determine which process we are and world size
-        self.set_world_ranks()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        # ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.is_global_zero and not torch.distributed.is_initialized():
-            log.info("-" * 100)
-            log.info(f"distributed_backend={self.distributed_backend}")
-            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
-            log.info("-" * 100)
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
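For reference, `setup_distributed` ends in `init_ddp_connection`, which reads the master address and port from the cluster environment and calls `torch.distributed.init_process_group` (visible as the unchanged context line in the last hunk above). Below is a generic, standalone sketch of that kind of environment-driven initialization; it is not the plugin's exact code, and the localhost defaults are assumptions for illustration:

import os
import torch.distributed as torch_distrib


def init_ddp_connection(global_rank: int, world_size: int, backend: str = "nccl") -> None:
    # Rendezvous info normally comes from the cluster environment; these
    # localhost defaults are just placeholders for a single-node run.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    if not torch_distrib.is_initialized():
        torch_distrib.init_process_group(backend, rank=global_rank, world_size=world_size)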

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 0 additions & 2 deletions
@@ -77,8 +77,6 @@ def distributed_sampler_kwargs(self):
         return distributed_sampler_kwargs
 
     def setup(self, model):
-        self._model = model
-
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
 
         # pass in a state q

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 0 additions & 10 deletions
@@ -192,17 +192,7 @@ def _load_config(self, config):
         return config
 
     def pre_dispatch(self):
-        self.set_world_ranks()
-        self.init_ddp_connection(self.global_rank, self.world_size)
-
         self.init_deepspeed()
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         self.barrier()
 
     def init_deepspeed(self):

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 0 additions & 8 deletions
@@ -53,14 +53,6 @@ def on_gpu(self):
     def lightning_module(self):
         return unwrap_lightning_module(self._model)
 
-    @abstractmethod
-    def setup(self, model):
-        raise NotImplementedError
-
-    def connect(self, model, *args, **kwargs):
-        self.setup(model)
-        return self.model
-
     @property
     def is_global_zero(self) -> bool:
         return self.global_rank == 0

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ def model_to_device(self) -> None:
 
         self._model.to(self.root_device)
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
         self.model_to_device()
         return self.model

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 1 addition & 6 deletions
@@ -39,13 +39,8 @@ def __init__(self, device: Union[torch.device, int]):
     def on_tpu(self) -> bool:
         return True
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
-        self._model = model
-        self.model_to_device()
-        return self._model
-
     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)
 
     def pre_dispatch(self) -> None:
         if isinstance(self.device, int):
if isinstance(self.device, int):

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 2 additions & 3 deletions
@@ -53,10 +53,9 @@ def __init__(
         self.tpu_local_core_rank = 0
         self.start_method = None
 
-    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
+    def setup(self, model: torch.nn.Module) -> torch.nn.Module:
         self.create_mp_queue()
-        self._model = model
-        return self._model
+        return self.model
 
     def create_mp_queue(self):
         self.start_method = 'fork'

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 12 additions & 2 deletions
@@ -34,9 +34,19 @@ def __init__(self) -> None:
         self._model = None
         self._results = None
 
-    @abstractmethod
     def connect(self, model: 'Module') -> None:
-        """Called by the accelerator to connect it with this plugin"""
+        """Called by the accelerator to connect the accelerator and the model with this plugin"""
+        self.model = model
+
+    def setup_environment(self) -> None:
+        """
+        Setup any processes or distributed connections.
+        This is called before the LightningModule/DataModule setup hook
+        which allows the user to access the accelerator environment before setup is complete.
+        """
+
+    def setup(self, model: 'Module') -> None:
+        """Called by the accelerator to finish setup."""
 
     @property
     @abstractmethod
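With these base-class changes, a custom training type plugin overrides `setup_environment` for process and connection bootstrapping and `setup` for work that should happen afterwards, while `connect` can usually be inherited since it now only stores the model. A minimal, hypothetical subclass sketch follows; it omits the abstract properties a real plugin must still implement:

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


class MyTrainingTypePlugin(TrainingTypePlugin):
    def setup_environment(self) -> None:
        # Runs before the LightningModule/DataModule `setup` hook:
        # spawn processes, initialize torch.distributed, etc.
        pass

    def setup(self, model) -> None:
        # Runs during accelerator setup, after the user's `setup` hook:
        # e.g. move the model to the target device.
        pass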
