
Commit 06da4fb

ddp plugins
Co-authored-by: Justus Schock <[email protected]>
1 parent 21d313e commit 06da4fb

3 files changed, +507 -0 lines changed

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
import os
import subprocess
import sys
from time import sleep
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, unwrap_lightning_module
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
    from hydra.utils import to_absolute_path, get_original_cwd
    from hydra.core.hydra_config import HydraConfig

if torch.distributed.is_available():
    from torch.distributed import ReduceOp
else:

    class ReduceOp:
        SUM = None


class DDPPlugin(ParallelPlugin):

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices,
        num_nodes=1,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm=False,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.dist = LightningDistributed()
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self.task_idx = None
        self.node_rank = 0
        self.num_processes = len(parallel_devices)

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def lightning_module(self):
        # the model may not be wrapped with DistributedDataParallel if calling this too early
        return unwrap_lightning_module(self._model)

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, model):
        self._model = model

        # start the other scripts
        # TODO: make sure this works, in torchelastic we should not launch child processes!
        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

    def _call_children_scripts(self):

        # bookkeeping of spawned processes
        assert self.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))

        # allow the user to pass the node rank
        node_rank = "0"
        node_rank = os.environ.get("NODE_RANK", node_rank)
        node_rank = os.environ.get("GROUP_RANK", node_rank)
        os.environ["NODE_RANK"] = node_rank
        os.environ["LOCAL_RANK"] = "0"

        # when user is using hydra find the absolute path
        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

        # pull out the commands used to run the script and resolve the abs file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception as e:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # use the same python interpreter and actually running
        command = [sys.executable] + command

        # the visible devices tell us how many GPUs we want to use.
        # when the trainer script was called the device has already been scoped by the time
        # code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone
        # but forward the GPUs selected via environment variables
        if self.parallel_devices is None:
            raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues
            # with dataloaders delay between 1-10 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes

    def configure_ddp(self):
        # if unset, default `find_unused_parameters` `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
        self._model = LightningDistributedDataParallel(
            self.model,
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        # TODO: From where to get cluster environment?
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
        torch_backend = "nccl" if self.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

    def pre_training(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_training
        # ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # on world_size=0 let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier()

    def post_training(self):
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]

    def barrier(self, *args, **kwargs):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def model_to_device(self):
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
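
For orientation, the sketch below wires the new DDPPlugin up by hand, using only what the diff above defines: the constructor signature and the ClusterEnvironment methods the plugin calls (local_rank, node_rank, master_address, master_port, world_size). EnvVarEnvironment is a hypothetical helper written for this example, not part of the commit, and the DDPPlugin module path is taken from the import in the second file below.

# Minimal usage sketch. EnvVarEnvironment is a hypothetical helper whose methods
# mirror exactly what DDPPlugin reads from its cluster_environment; it is not a
# class added by this commit.
import os

import torch

from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class EnvVarEnvironment(ClusterEnvironment):
    """Reads the same environment variables that _call_children_scripts() exports."""

    def master_address(self):
        return os.environ.get("MASTER_ADDR", "127.0.0.1")

    def master_port(self):
        return int(os.environ.get("MASTER_PORT", 29500))

    def world_size(self):
        return int(os.environ.get("WORLD_SIZE", 1))

    def local_rank(self):
        return int(os.environ.get("LOCAL_RANK", 0))

    def node_rank(self):
        return int(os.environ.get("NODE_RANK", 0))


plugin = DDPPlugin(
    parallel_devices=[torch.device("cuda", i) for i in range(2)],  # one device per local process
    num_nodes=1,
    cluster_environment=EnvVarEnvironment(),
    sync_batchnorm=False,
    find_unused_parameters=False,  # forwarded to DistributedDataParallel via **kwargs
)

# The trainer (not shown) would then call plugin.setup(model) on rank 0, which launches the
# child scripts, and plugin.pre_training() in every process to set ranks and init the process group.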
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
import torch

from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.core.step_result import Result


class DDP2Plugin(DDPPlugin):

    def setup(self, model):
        self._model = model
        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()
        # the difference to DDP is that we don't call children processes here

    def reduce(self, output, *args, **kwargs):
        if isinstance(output, Result):
            output.dp_reduce()

        elif isinstance(output, torch.Tensor):
            output = output.mean()

        return output

    @property
    def root_device(self):
        return self.parallel_devices[0]

    def model_to_device(self):
        # no need to do anything when model is wrapped in torch.nn.DataParallel
        pass

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
        return distributed_sampler_kwargs

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank
        self.world_size = self.num_nodes
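
And a corresponding sketch for DDP2Plugin, highlighting where it departs from its parent: setup() no longer launches child scripts, reduce() takes a DP-style mean instead of an all-reduce, and the distributed sampler sees one replica per node. The ddp2 module path is an assumption, and the cluster environment is omitted because the constructor accepts None and it is only consulted later in setup()/set_world_ranks().

# Minimal sketch, assuming the module path pytorch_lightning.plugins.training_type.ddp2.
import torch

from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin  # path assumed

plugin = DDP2Plugin(
    parallel_devices=[torch.device("cuda", i) for i in range(2)],
    num_nodes=2,
    cluster_environment=None,  # only needed once setup()/set_world_ranks() run
)

# DP-style reduction: a plain tensor is averaged locally rather than synced across processes.
print(plugin.reduce(torch.tensor([1.0, 3.0])))  # tensor(2.)

# Each node holds one replica of the data, so the sampler is built with
# num_replicas=num_nodes (here 2) and rank=global_rank, per distributed_sampler_kwargs above.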
