@@ -11,21 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-import os
-from typing import Any, List, Optional, Union
-
-import torch
-import torch.distributed as torch_distrib
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel
-
-from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
-from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
-from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
 
 
 try:
@@ -37,167 +23,15 @@
     HYDRA_AVAILABLE = True
 
 
-class DDPCPUHPCAccelerator(Accelerator):
+class DDPCPUHPCAccelerator(DDPHPCAccelerator):
 
     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
         super().__init__(trainer, cluster_environment, ddp_plugin)
-        self.task_idx = None
-        self._has_spawned_children = False
-        self.dist = LightningDistributed()
         self.nickname = 'ddp_cpu'
 
-    def setup(self, model):
-        self.trainer.model = model
-        self.task_idx = self.cluster_environment.local_rank()
-
-    def train(self):
-        model = self.trainer.model
-        self.ddp_train(process_idx=self.task_idx, model=model)
-
-    def set_world_ranks(self, process_idx):
-        self.trainer.local_rank = process_idx
-        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
-        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
-
     def model_to_device(self, model, process_idx):
         model.cpu()
 
     def get_device_ids(self):
         device_ids = None
         return device_ids
-
-    def training_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model(*args)
-        else:
-            output = self.trainer.model(*args)
-        return output
-
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def barrier(self, name: Optional[str] = None):
-        if torch_distrib.is_initialized():
-            torch_distrib.barrier()
-
-    def early_stopping_should_stop(self, pl_module):
-        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
-        should_stop = stop == self.trainer.world_size
-        return should_stop
-
-    def broadcast(self, obj, src=0):
-        return self.dist.broadcast(obj)
-
-    def ddp_train(self, process_idx, model):
-        """
-        Entry point for ddp
-
-        Args:
-            process_idx:
-            mp_queue: multiprocessing queue
-            model:
-
-        Returns:
-            Dict with evaluation results
-
-        """
-        # determine which process we are and world size
-        self.set_world_ranks(process_idx)
-
-        # toggle prog bar
-        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
-            self.trainer.progress_bar_callback.disable()
-
-        # set warning rank
-        rank_zero_only.rank = self.trainer.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        model.trainer = self.trainer
-        self.init_ddp_connection(
-            self.trainer.global_rank,
-            self.trainer.world_size,
-            self.trainer.is_slurm_managing_tasks
-        )
-
-        # call setup after the ddp process has connected
-        self.trainer.call_setup_hook(model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
-            log.info('-' * 100)
-            log.info(f'distributed_backend={self.trainer.distributed_backend} (TORCH_ELASTIC)')
-            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
-            log.info('-' * 100)
-
-        # call sync_bn before .cuda(), configure_apex and configure_ddp
-        if self.trainer.sync_batchnorm:
-            model = self.configure_sync_batchnorm(model)
-
-        # move the model to the correct device
-        self.model_to_device(model, process_idx)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.setup_optimizers(model)
-
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
-        # 16-bit
-        model = self.trainer.precision_connector.connect(model)
-
-        # device ids change depending on the DDP setup
-        device_ids = self.get_device_ids()
-
-        # allow user to configure ddp
-        model = self.configure_ddp(model, device_ids)
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-
-        # clean up memory
-        torch.cuda.empty_cache()
-
-        return results
-
-    def configure_ddp(
-        self, model: LightningModule, device_ids: List[int]
-    ) -> DistributedDataParallel:
-        model = self.ddp_plugin.configure_ddp(model, device_ids)
-        return model
-
-    def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
-        """
-        Add global batchnorm for a model spread across multiple GPUs and nodes.
-
-        Override to synchronize batchnorm between specific process groups instead
-        of the whole world or use a different sync_bn like `apex`'s version.
-
-        Args:
-            model: pointer to current :class:`LightningModule`.
-
-        Return:
-            LightningModule with batchnorm layers synchronized between process groups
-        """
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
-
-        return model
-
-    def sync_tensor(self,
-                    tensor: Union[torch.Tensor],
-                    group: Optional[Any] = None,
-                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
-        return sync_ddp_if_available(tensor, group, reduce_op)
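
Net effect of the diff: DDPCPUHPCAccelerator stops re-implementing the DDP machinery (rank setup, ddp_train, step methods, barrier/broadcast, configure_ddp, sync_tensor) and inherits it from DDPHPCAccelerator, keeping only the CPU-specific overrides. A minimal sketch of the resulting class, assembled from the "+" side above; it assumes DDPHPCAccelerator supplies the removed behaviour:

# Sketch only: mirrors the '+' lines of this diff, not an authoritative copy of the file.
# Assumes DDPHPCAccelerator provides the deleted DDP logic (ddp_train, training_step,
# barrier, broadcast, configure_ddp, sync_tensor, ...), which the new base class implies.
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator


class DDPCPUHPCAccelerator(DDPHPCAccelerator):

    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.nickname = 'ddp_cpu'

    def model_to_device(self, model, process_idx):
        # CPU variant: keep the model on the CPU instead of moving it to a GPU
        model.cpu()

    def get_device_ids(self):
        # no device ids for CPU-only DDP, so DistributedDataParallel is built without them
        device_ids = None
        return device_ids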