import os
import subprocess
import sys
from time import sleep
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib

from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, unwrap_lightning_module
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
    from hydra.utils import to_absolute_path, get_original_cwd
    from hydra.core.hydra_config import HydraConfig

if torch.distributed.is_available():
    from torch.distributed import ReduceOp
else:

    class ReduceOp:
        SUM = None

class DDPPlugin(ParallelPlugin):

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices,
        num_nodes=1,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm=False,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.dist = LightningDistributed()
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self.task_idx = None
        self.node_rank = 0
        self.num_processes = len(parallel_devices)

    @property
    def root_device(self):
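        # the device this process uses, selected by its local rank among parallel_devices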
        return self.parallel_devices[self.local_rank]

    @property
    def lightning_module(self):
        # the model may not be wrapped with DistributedDataParallel if calling this too early
        return unwrap_lightning_module(self._model)

    @property
    def distributed_sampler_kwargs(self):
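        # kwargs for torch's DistributedSampler: one replica per process across all nodes, indexed by global rank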
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, model):
        self._model = model

        # start the other scripts
        # TODO: make sure this works, in torchelastic we should not launch child processes!
        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

    def _call_children_scripts(self):

        # bookkeeping of spawned processes
        assert self.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP environment variables
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))

        # allow the user to pass the node rank
        node_rank = "0"
        node_rank = os.environ.get("NODE_RANK", node_rank)
        node_rank = os.environ.get("GROUP_RANK", node_rank)
        os.environ["NODE_RANK"] = node_rank
        os.environ["LOCAL_RANK"] = "0"

        # when the user is using hydra, resolve the absolute path with hydra's helper
        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

        # pull out the command used to run the script and resolve its absolute file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # run the command with the same Python interpreter that launched this process
        command = [sys.executable] + command

        # the visible devices tell us how many GPUs we want to use.
        # by the time code reaches this point, the devices have already been scoped for the
        # current process, so leave CUDA_VISIBLE_DEVICES alone and forward the selected GPUs
        # to the child scripts via environment variables instead
        if self.parallel_devices is None:
            raise MisconfigurationException("you selected (distributed_backend = ddp) but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove the env var if the global seed was not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders,
            # so stagger the launches with a random 1-5 second delay
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self):
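        # flatten (node_rank, local_rank) into a single global rank:
        # global_rank = node_rank * processes_per_node + local_rank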
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes

    def configure_ddp(self):
        # if not set by the user, default `find_unused_parameters` to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
        self._model = LightningDistributedDataParallel(
            self.model,
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def determine_ddp_device_ids(self):
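        # DistributedDataParallel expects a single device id per process on GPU; on CPU, pass None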
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        # TODO: From where to get cluster environment?
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
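        # NCCL is the backend of choice for GPU training; Gloo is the CPU fallback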
        torch_backend = "nccl" if self.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

    def pre_training(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and the world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up the server using proc 0's ip address;
        # try to init at most 20 times in case ports are taken
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_training
        # ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # from global rank zero, let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier()

    def post_training(self):
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]

    def barrier(self, *args, **kwargs):
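        # a no-op until the default process group has been initialized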
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
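        # note: `src` is not forwarded here; the broadcast uses the rank and device stored on self.dist in pre_training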
        return self.dist.broadcast(obj)

    def model_to_device(self):
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
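        # all-reduce tensors across the DDP process group; non-tensor outputs are returned unchanged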
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
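

# A minimal construction sketch (hypothetical, for illustration only): in regular use the
# Trainer builds this plugin internally, so the devices and the extra DDP kwarg below are
# assumptions; extra keyword arguments end up in `self._ddp_kwargs` and are forwarded to
# DistributedDataParallel in `configure_ddp`.
if __name__ == "__main__":
    example_devices = [torch.device("cuda", i) for i in range(torch.cuda.device_count())] or [torch.device("cpu")]
    plugin = DDPPlugin(
        parallel_devices=example_devices,
        num_nodes=1,
        sync_batchnorm=False,
        find_unused_parameters=False,  # stored in self._ddp_kwargs, passed through in configure_ddp
    )
    print(f"DDPPlugin configured for {plugin.num_processes} process(es) per node")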