
Commit ee9b3fe

SeanNaren and tchaton authored
[feat] pp 1/n (#5016)
* Added changes for RPC plugin
* Add missing kwargs
* Fix code format
* Loading refactors by introducing is_distributed var, fix optimizer step flow
* Add rpc guard
* Added docstrings and typing
* resolve comments
* Add additional rpc hook, refactor name of exit process hook for clarity
* remove annotation
* Modify behaviour to allow optional return, add test for rpc plugin
* resolve tests
* rename is_ddp_based
* update
* update for windows
* update
* resolve test
* code smell
* Revert back to init_ddp_connection for backwards compat
* Swap to explicit name for property
* Add missing speed parity increase for CI variability, fix call counts for child process

Co-authored-by: tchaton <[email protected]>
1 parent ddd3eda · commit ee9b3fe

23 files changed: +560 / -60 lines
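For orientation (not part of the diff below): the RPC hooks added in this commit are meant to be driven through the Trainer's existing plugins argument, the same way a DDP plugin is passed today. A minimal sketch under that assumption; MyPipePlugin, its empty body, and the no-argument constructor are hypothetical, since the concrete pipe plugin is only expected in the follow-up "pp" PRs:

import pytorch_lightning as pl
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin  # module introduced by this PR


class MyPipePlugin(RPCPlugin):
    # Hypothetical subclass for illustration; this commit only wires up the
    # RPCPlugin hook points in the accelerators.
    pass


trainer = pl.Trainer(
    gpus=2,
    accelerator='ddp',
    plugins=[MyPipePlugin()],  # picked up by the accelerator as its ddp_plugin
)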

benchmarks/test_sharded_parity.py

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
         gpus=2,
         accelerator='ddp_spawn',
         model_cls=SeedTrainLoaderManualModel,
+        max_percent_speed_diff=0.25  # Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
pytorch_lightning/accelerators/accelerator.py

Lines changed: 13 additions & 4 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from contextlib import contextmanager
 from enum import Enum
 from typing import Any, Optional, Union
@@ -21,10 +20,8 @@
 from torch.optim import Optimizer

 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict

 if torch.distributed.is_available():
@@ -222,6 +219,18 @@ def __setstate__(self, d):
     def on_save(self, checkpoint):
         return checkpoint

+    @property
+    def rpc_enabled(self):
+        return self.ddp_plugin is not None and isinstance(self.ddp_plugin, RPCPlugin)
+
+    @property
+    def distributed_sampler_kwargs(self):
+        raise NotImplementedError
+
+    @property
+    def require_distributed_sampler(self):
+        raise NotImplementedError
+
     @contextmanager
     def block_ddp_plugin_sync_behaviour(self):
         """

pytorch_lightning/accelerators/cpu_accelerator.py

Lines changed: 4 additions & 0 deletions
@@ -90,3 +90,7 @@ def sync_tensor(self,
                     group: Optional[Any] = None,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return tensor
+
+    @property
+    def require_distributed_sampler(self):
+        return False

pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 34 additions & 3 deletions
@@ -23,6 +23,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available

@@ -101,9 +102,11 @@ def set_world_ranks(self, process_idx):
     def broadcast(self, obj, src=0):
         return self.dist.broadcast(obj)

-    def model_to_device(self, model, process_idx):
+    def init_device(self, process_idx):
         self.trainer.root_gpu = process_idx
         torch.cuda.set_device(self.trainer.root_gpu)
+
+    def model_to_device(self, model):
         model.cuda(self.trainer.root_gpu)

     def get_device_ids(self):
@@ -133,6 +136,9 @@ def ddp_train(self, process_idx, mp_queue, model):
         # set warning rank
         rank_zero_only.rank = self.trainer.global_rank

+        # Initialize cuda device
+        self.init_device(process_idx)
+
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
@@ -143,6 +149,15 @@ def ddp_train(self, process_idx, mp_queue, model):
             self.trainer.is_slurm_managing_tasks
         )

+        if isinstance(self.ddp_plugin, RPCPlugin):
+            if not self.ddp_plugin.is_main_rpc_process:
+                self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
+                self.ddp_plugin.exit_rpc_process()
+                if self.ddp_plugin.return_after_exit_rpc_process:
+                    return
+            else:
+                self.ddp_plugin.on_main_rpc_connection(self.trainer)
+
         # call setup after the ddp process has connected
         self.trainer.call_setup_hook(model)

@@ -158,12 +173,14 @@ def ddp_train(self, process_idx, mp_queue, model):
             model = self.configure_sync_batchnorm(model)

         # move the model to the correct device
-        self.model_to_device(model, process_idx)
+        self.model_to_device(model)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
         self.setup_optimizers(model)

+        self.ddp_plugin.on_after_setup_optimizers(self.trainer)
+
         # set model properties before going into wrapper
         self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -189,7 +206,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         return results

     def configure_ddp(
-        self, model: LightningModule, device_ids: List[int]
+            self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
@@ -219,3 +236,17 @@ def sync_tensor(self,

     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
+
+    @property
+    def distributed_sampler_kwargs(self):
+        distributed_sampler_kwargs = dict(
+            num_replicas=self.trainer.num_nodes,
+            rank=self.trainer.global_rank
+        )
+        if self.ddp_plugin is not None:
+            distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
+        return distributed_sampler_kwargs
+
+    @property
+    def require_distributed_sampler(self):
+        return True
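The RPCPlugin branch in ddp_train above (repeated in the other DDP accelerators below) defines the lifecycle such a plugin must implement: non-main RPC processes get an exit hook, shut down their RPC agent, and optionally return early, while the main process gets a connection hook. A hedged sketch of a plugin exposing that surface; the rank policy and the empty bodies are invented, this is not the real RPCPlugin implementation:

import torch.distributed as torch_distrib

from pytorch_lightning.plugins.rpc_plugin import RPCPlugin


class SketchRPCPlugin(RPCPlugin):
    # Illustrative only: shows the hooks the accelerators call on the plugin.

    @property
    def is_main_rpc_process(self):
        # Example policy: global rank 0 drives training, other ranks host RPC workers.
        return torch_distrib.get_rank() == 0

    @property
    def return_after_exit_rpc_process(self):
        # True -> worker ranks leave ddp_train() right after tearing down RPC.
        return True

    def on_accelerator_exit_rpc_process(self, trainer):
        # Last chance to adjust trainer state on non-main RPC processes.
        pass

    def exit_rpc_process(self):
        # Shut down the torch.distributed.rpc agent on this rank.
        pass

    def on_main_rpc_connection(self, trainer):
        # Set up remote modules / pipeline stages once RPC is connected.
        pass

    def on_after_setup_optimizers(self, trainer):
        # Called right after trainer.optimizers is populated; a pipe plugin
        # could rewrap or replace the optimizers here.
        pass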

pytorch_lightning/accelerators/ddp_accelerator.py

Lines changed: 39 additions & 7 deletions
@@ -27,6 +27,7 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -162,8 +163,11 @@ def _step(self, args):
         return output

     def barrier(self, name: Optional[str] = None):
-        if torch_distrib.is_initialized():
-            torch_distrib.barrier()
+        if self.rpc_enabled:
+            # Allow RPC to handle barrier on main RPC processes
+            self.ddp_plugin.barrier()
+        elif torch_distrib.is_initialized():
+            torch_distrib.barrier(group=self.ddp_plugin.data_parallel_group)

     def _check_can_spawn_children(self):
         if self._has_spawned_children:
@@ -177,9 +181,11 @@ def set_world_ranks(self, process_idx):
         self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
         self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

-    def model_to_device(self, model, process_idx):
+    def init_device(self, process_idx):
         self.trainer.root_gpu = self.trainer.data_parallel_device_ids[self.trainer.local_rank]
         torch.cuda.set_device(self.trainer.root_gpu)
+
+    def model_to_device(self, model):
         model.cuda(self.trainer.root_gpu)

     def get_device_ids(self):
@@ -192,12 +198,12 @@ def on_train_end(self):
     def early_stopping_should_stop(self, pl_module):
         stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
         torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
-        torch_distrib.barrier()
+        self.barrier('early_stopping')
         should_stop = stop == self.trainer.world_size
         return should_stop

     def broadcast(self, obj, src=0):
-        return self.dist.broadcast(obj)
+        return self.dist.broadcast(obj, group=self.ddp_plugin.data_parallel_group)

     def ddp_train(self, process_idx, model):
         """
@@ -226,6 +232,9 @@ def ddp_train(self, process_idx, model):
         # set warning rank
         rank_zero_only.rank = self.trainer.global_rank

+        # Initialize cuda device
+        self.init_device(process_idx)
+
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
@@ -236,6 +245,15 @@ def ddp_train(self, process_idx, model):
             self.trainer.is_slurm_managing_tasks
         )

+        if isinstance(self.ddp_plugin, RPCPlugin):
+            if not self.ddp_plugin.is_main_rpc_process:
+                self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
+                self.ddp_plugin.exit_rpc_process()
+                if self.ddp_plugin.return_after_exit_rpc_process:
+                    return
+            else:
+                self.ddp_plugin.on_main_rpc_connection(self.trainer)
+
         # call setup after the ddp process has connected
         self.trainer.call_setup_hook(model)

@@ -251,7 +269,7 @@ def ddp_train(self, process_idx, model):
             model = self.configure_sync_batchnorm(model)

         # move the model to the correct device
-        self.model_to_device(model, process_idx)
+        self.model_to_device(model)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -284,7 +302,7 @@ def ddp_train(self, process_idx, model):
         return results

     def configure_ddp(
-        self, model: LightningModule, device_ids: List[int]
+            self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
@@ -317,3 +335,17 @@ def sync_tensor(self,

     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
+
+    @property
+    def distributed_sampler_kwargs(self):
+        distributed_sampler_kwargs = dict(
+            num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
+            rank=self.trainer.global_rank
+        )
+        if self.ddp_plugin is not None:
+            distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
+        return distributed_sampler_kwargs
+
+    @property
+    def require_distributed_sampler(self):
+        return True
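For context, require_distributed_sampler and distributed_sampler_kwargs feed the trainer's dataloader setup: when the former is True, the kwargs above are essentially what gets forwarded to PyTorch's DistributedSampler, after the ddp_plugin has had a chance to rewrite them (an RPC/pipe plugin may shrink num_replicas to the data-parallel world size only). A rough sketch of that consumption path; the helper function is illustrative, not the trainer's actual code:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def prepare_dataloader(dataloader: DataLoader, accelerator) -> DataLoader:
    # Hypothetical helper mirroring what the trainer does with the new properties.
    if not accelerator.require_distributed_sampler:
        return dataloader

    # e.g. {'num_replicas': num_nodes * num_processes, 'rank': global_rank},
    # possibly adjusted by the configured ddp_plugin.
    kwargs = accelerator.distributed_sampler_kwargs
    sampler = DistributedSampler(dataloader.dataset, **kwargs)
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,
        num_workers=dataloader.num_workers,
    )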

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

Lines changed: 27 additions & 1 deletion
@@ -24,6 +24,7 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import (
     find_free_network_port,
@@ -107,6 +108,15 @@ def ddp_train(self, process_idx, mp_queue, model):
             self.trainer.is_slurm_managing_tasks
         )

+        if isinstance(self.ddp_plugin, RPCPlugin):
+            if not self.ddp_plugin.is_main_rpc_process:
+                self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
+                self.ddp_plugin.exit_rpc_process()
+                if self.ddp_plugin.return_after_exit_rpc_process:
+                    return
+            else:
+                self.ddp_plugin.on_main_rpc_connection(self.trainer)
+
         # call setup after the ddp process has connected
         self.trainer.call_setup_hook(model)

@@ -128,6 +138,8 @@ def ddp_train(self, process_idx, mp_queue, model):
         # allow for lr schedulers as well
         self.setup_optimizers(model)

+        self.ddp_plugin.on_after_setup_optimizers(self.trainer)
+
         # set model properties before going into wrapper
         self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -221,7 +233,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
         mp_queue.put(results)

     def configure_ddp(
-        self, model: LightningModule, device_ids: List[int]
+            self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
@@ -251,3 +263,17 @@ def sync_tensor(self,

     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
+
+    @property
+    def distributed_sampler_kwargs(self):
+        distributed_sampler_kwargs = dict(
+            num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
+            rank=self.trainer.global_rank
+        )
+        if self.ddp_plugin is not None:
+            distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
+        return distributed_sampler_kwargs
+
+    @property
+    def require_distributed_sampler(self):
+        return True
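The on_after_setup_optimizers call that each ddp_train now makes gives the plugin a look at the optimizers the trainer just built. A hedged sketch of a plugin using that hook; the subclass is hypothetical and assumes the base DDPPlugin gains this hook in this commit (its file is among the 23 changed but not shown in this section):

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class RecordOptimizersPlugin(DDPPlugin):
    # Hypothetical plugin: demonstrates the on_after_setup_optimizers hook only.

    def on_after_setup_optimizers(self, trainer):
        # At this point trainer.optimizers / trainer.lr_schedulers are populated,
        # so a plugin that owns the optimizer step (e.g. a pipe/RPC plugin)
        # could stash or rewrap them here.
        self.optimizers = list(trainer.optimizers)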

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 31 additions & 3 deletions
@@ -23,6 +23,7 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available

@@ -62,9 +63,11 @@ def set_world_ranks(self, process_idx):
         self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
         self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

-    def model_to_device(self, model, process_idx):
+    def init_device(self, process_idx):
         self.trainer.root_gpu = process_idx
         torch.cuda.set_device(self.trainer.root_gpu)
+
+    def model_to_device(self, model):
         model.cuda(self.trainer.root_gpu)

     def get_device_ids(self):
@@ -136,6 +139,15 @@ def ddp_train(self, process_idx, model):
             self.trainer.is_slurm_managing_tasks
         )

+        if isinstance(self.ddp_plugin, RPCPlugin):
+            if not self.ddp_plugin.is_main_rpc_process:
+                self.ddp_plugin.on_accelerator_exit_rpc_process(self.trainer)
+                self.ddp_plugin.exit_rpc_process()
+                if self.ddp_plugin.return_after_exit_rpc_process:
+                    return
+            else:
+                self.ddp_plugin.on_main_rpc_connection(self.trainer)
+
         # call setup after the ddp process has connected
         self.trainer.call_setup_hook(model)

@@ -151,12 +163,14 @@ def ddp_train(self, process_idx, model):
             model = self.configure_sync_batchnorm(model)

         # move the model to the correct device
-        self.model_to_device(model, process_idx)
+        self.model_to_device(model)

         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
         self.setup_optimizers(model)

+        self.ddp_plugin.on_after_setup_optimizers(self.trainer)
+
         # set model properties before going into wrapper
         self.trainer.model_connector.copy_trainer_model_properties(model)

@@ -183,7 +197,7 @@ def ddp_train(self, process_idx, model):
         return results

     def configure_ddp(
-        self, model: LightningModule, device_ids: List[int]
+            self, model: LightningModule, device_ids: List[int]
     ) -> DistributedDataParallel:
         model = self.ddp_plugin.configure_ddp(model, device_ids)
         return model
@@ -213,3 +227,17 @@ def sync_tensor(self,

     def get_reference_model(self, model) -> LightningModule:
         return self.ddp_plugin.get_model_from_plugin(model)
+
+    @property
+    def distributed_sampler_kwargs(self):
+        distributed_sampler_kwargs = dict(
+            num_replicas=self.trainer.num_nodes * self.trainer.num_processes,
+            rank=self.trainer.global_rank
+        )
+        if self.ddp_plugin is not None:
+            distributed_sampler_kwargs = self.ddp_plugin.distributed_sampler_kwargs(distributed_sampler_kwargs)
+        return distributed_sampler_kwargs
+
+    @property
+    def require_distributed_sampler(self):
+        return True
