# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from time import sleep
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib

from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
    find_free_network_port,
    rank_zero_only,
    ReduceOp,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import get_original_cwd, to_absolute_path


class DDPPlugin(ParallelPlugin):
    """
    Plugin for multi-process single-device training on one or multiple nodes.

    The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`,
    where N is the number of devices (e.g. GPU) per node.
    It is very similar to how :mod:`torch.distributed.launch` launches processes.
    """
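    # Construction sketch (for illustration only; the device list and keyword arguments
    # below are assumptions, and in practice the Trainer is expected to build and
    # configure this plugin itself):
    #
    #     plugin = DDPPlugin(
    #         parallel_devices=[torch.device("cuda", 0), torch.device("cuda", 1)],
    #         num_nodes=1,
    #         sync_batchnorm=False,
    #         find_unused_parameters=False,  # extra kwargs are forwarded to DDP in `configure_ddp`
    #     )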

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices,
        num_nodes: int = 1,
        cluster_environment: Optional[ClusterEnvironment] = None,
        sync_batchnorm: bool = False,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.dist = LightningDistributed()
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self.task_idx = None
        self.node_rank = 0
        self.num_processes = len(parallel_devices)

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def lightning_module(self):
        # the model may not be wrapped with DistributedDataParallel if calling this too early
        # fixme: uncomment when this class will actually be used
        # return unwrap_lightning_module(self._model)
        pass

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs
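    # The returned kwargs mirror the signature of ``torch.utils.data.distributed.DistributedSampler``;
    # an illustrative (hypothetical) caller-side use: ``DistributedSampler(dataset, **plugin.distributed_sampler_kwargs)``.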

    def setup(self, model):
        self._model = model

        # start the other scripts
        # TODO: make sure this works, in torchelastic we should not launch child processes!
        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

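    # ``_call_children_scripts`` exports PL_IN_DDP_SUBPROCESS=1 before spawning, so child
    # processes that re-enter ``setup`` skip the spawning branch above and only the original
    # process launches subprocesses.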
    def _call_children_scripts(self):

        # bookkeeping of spawned processes
        assert self.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))

        # allow the user to pass the node rank
        node_rank = "0"
        node_rank = os.environ.get("NODE_RANK", node_rank)
        node_rank = os.environ.get("GROUP_RANK", node_rank)
        os.environ["NODE_RANK"] = node_rank
        os.environ["LOCAL_RANK"] = "0"

        # when the user is using hydra, find the absolute path
        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

        # pull out the command used to run the script and resolve the absolute file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # run the script with the same Python interpreter that launched this process
        command = [sys.executable] + command

        # the visible devices tell us how many GPUs we want to use.
        # when the trainer script was called, the devices had already been scoped by the time
        # code reaches this point. so, to call the scripts, we need to leave CUDA_VISIBLE_DEVICES
        # alone but forward the GPUs selected via an environment variable
        if self.parallel_devices is None:
            raise MisconfigurationException("You selected distributed_backend='ddp' but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders,
            # so stagger the launches with a random delay of 1-5 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)
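    # Rough illustration (assumed values, not produced verbatim by this code): on a single
    # node with 2 GPUs, the launching process keeps LOCAL_RANK=0 and spawns one child with
    # an environment roughly equivalent to
    #
    #     LOCAL_RANK=1 NODE_RANK=0 MASTER_ADDR=127.0.0.1 MASTER_PORT=<free port> \
    #         WORLD_SIZE=2 python /abs/path/to/train_script.py <original argv>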

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes
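        # e.g. with (hypothetical) num_nodes=2, num_processes=4, node_rank=1, local_rank=2:
        # global_rank = 1 * 4 + 2 = 6 and world_size = 2 * 4 = 8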

    def configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
        self._model = LightningDistributedDataParallel(
            self.model,
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]
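    # ``device_ids=None`` is what ``DistributedDataParallel`` expects for CPU modules; a
    # single-GPU process instead gets a one-element list such as ``[0]``.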

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        # TODO: From where to get cluster environment?
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
        torch_backend = "nccl" if self.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
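    # With no explicit ``init_method``, ``init_process_group`` defaults to the ``env://``
    # rendezvous and reads MASTER_ADDR / MASTER_PORT (set above) from the environment.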

    def pre_training(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up server using proc 0's ip address
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_training
        # ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # on global rank 0, let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier()

    def post_training(self):
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]

    def barrier(self, *args, **kwargs):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def model_to_device(self):
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
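    # Only ``torch.Tensor`` outputs are synchronized across processes; any other object is
    # returned unchanged. Illustrative (hypothetical) call: ``plugin.reduce(loss, reduce_op=ReduceOp.SUM)``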