
Commit f8d1488

williamFalcon authored and SeanNaren committed
ref: unify slurm and TE under backendPlugin 5/n (#4582)

* ref: unify slurm and TE under backendPlugin 4/n
* ref: unify slurm and TE under backendPlugin 5/n

(cherry picked from commit 3ba48d3)
1 parent 6da12dd · commit f8d1488

3 files changed: +12 -174 lines changed


pytorch_lightning/accelerators/ddp2_accelerator.py

Lines changed: 9 additions & 5 deletions
@@ -11,22 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-
 import os
 
 import torch
 import torch.distributed as torch_distrib
 
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
 from torch.nn.parallel import DistributedDataParallel
-from typing import List, Optional
+from typing import List, Optional, Union, Any
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -203,3 +201,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
 
         return model
+
+    def sync_tensor(self,
+                    tensor: Union[torch.Tensor],
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        return sync_ddp_if_available(tensor, group, reduce_op)
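
The `sync_tensor` hook added to `DDP2Accelerator` above simply forwards to `sync_ddp_if_available`. As a rough sketch of what such a helper amounts to (my own illustration, not code from this commit nor the actual implementation in `pytorch_lightning.utilities.distributed`), an all-reduce wrapper could look like this:

# Illustrative sketch only, not part of this commit: approximates what a
# sync_tensor()/sync_ddp_if_available() style helper does -- all-reduce a tensor
# across DDP processes when torch.distributed is initialized, otherwise no-op.
from typing import Any, Optional, Union

import torch
import torch.distributed as torch_distrib


def sync_tensor_sketch(tensor: torch.Tensor,
                       group: Optional[Any] = None,
                       reduce_op: Optional[Union[torch_distrib.ReduceOp, str]] = None) -> torch.Tensor:
    # outside a distributed run there is nothing to reduce
    if not (torch_distrib.is_available() and torch_distrib.is_initialized()):
        return tensor

    divide_by_world_size = False
    op = torch_distrib.ReduceOp.SUM if reduce_op is None else reduce_op
    if isinstance(op, str):
        # allow string reduce ops such as "sum", "avg"/"mean", "max", ...
        if op.lower() in ("avg", "mean"):
            op = torch_distrib.ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(torch_distrib.ReduceOp, op.upper())

    torch_distrib.all_reduce(tensor, op=op, group=group, async_op=False)
    if divide_by_world_size:
        tensor = tensor / torch_distrib.get_world_size(group)
    return tensor

With a hook of this shape on the accelerator, `self.sync_tensor(metric)` can reduce a logged tensor across ranks regardless of whether the processes were launched by SLURM or TorchElastic, which is the point of unifying the two behind one backend plugin.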

pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py

Lines changed: 2 additions & 168 deletions
@@ -11,21 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-import os
-from typing import Any, List, Optional, Union
-
-import torch
-import torch.distributed as torch_distrib
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel
-
-from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
-from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
-from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
 
 
 try:
@@ -37,167 +23,15 @@
     HYDRA_AVAILABLE = True
 
 
-class DDPCPUHPCAccelerator(Accelerator):
+class DDPCPUHPCAccelerator(DDPHPCAccelerator):
 
     def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
         super().__init__(trainer, cluster_environment, ddp_plugin)
-        self.task_idx = None
-        self._has_spawned_children = False
-        self.dist = LightningDistributed()
         self.nickname = 'ddp_cpu'
 
-    def setup(self, model):
-        self.trainer.model = model
-        self.task_idx = self.cluster_environment.local_rank()
-
-    def train(self):
-        model = self.trainer.model
-        self.ddp_train(process_idx=self.task_idx, model=model)
-
-    def set_world_ranks(self, process_idx):
-        self.trainer.local_rank = process_idx
-        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
-        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
-
     def model_to_device(self, model, process_idx):
         model.cpu()
 
     def get_device_ids(self):
         device_ids = None
         return device_ids
-
-    def training_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model(*args)
-        else:
-            output = self.trainer.model(*args)
-        return output
-
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def barrier(self, name: Optional[str] = None):
-        if torch_distrib.is_initialized():
-            torch_distrib.barrier()
-
-    def early_stopping_should_stop(self, pl_module):
-        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
-        should_stop = stop == self.trainer.world_size
-        return should_stop
-
-    def broadcast(self, obj, src=0):
-        return self.dist.broadcast(obj)
-
-    def ddp_train(self, process_idx, model):
-        """
-        Entry point for ddp
-
-        Args:
-            process_idx:
-            mp_queue: multiprocessing queue
-            model:
-
-        Returns:
-            Dict with evaluation results
-
-        """
-        # determine which process we are and world size
-        self.set_world_ranks(process_idx)
-
-        # toggle prog bar
-        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
-            self.trainer.progress_bar_callback.disable()
-
-        # set warning rank
-        rank_zero_only.rank = self.trainer.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        model.trainer = self.trainer
-        self.init_ddp_connection(
-            self.trainer.global_rank,
-            self.trainer.world_size,
-            self.trainer.is_slurm_managing_tasks
-        )
-
-        # call setup after the ddp process has connected
-        self.trainer.call_setup_hook(model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
-            log.info('-' * 100)
-            log.info(f'distributed_backend={self.trainer.distributed_backend} (TORCH_ELASTIC)')
-            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
-            log.info('-' * 100)
-
-        # call sync_bn before .cuda(), configure_apex and configure_ddp
-        if self.trainer.sync_batchnorm:
-            model = self.configure_sync_batchnorm(model)
-
-        # move the model to the correct device
-        self.model_to_device(model, process_idx)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.setup_optimizers(model)
-
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
-        # 16-bit
-        model = self.trainer.precision_connector.connect(model)
-
-        # device ids change depending on the DDP setup
-        device_ids = self.get_device_ids()
-
-        # allow user to configure ddp
-        model = self.configure_ddp(model, device_ids)
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-
-        # clean up memory
-        torch.cuda.empty_cache()
-
-        return results
-
-    def configure_ddp(
-            self, model: LightningModule, device_ids: List[int]
-    ) -> DistributedDataParallel:
-        model = self.ddp_plugin.configure_ddp(model, device_ids)
-        return model
-
-    def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
-        """
-        Add global batchnorm for a model spread across multiple GPUs and nodes.
-
-        Override to synchronize batchnorm between specific process groups instead
-        of the whole world or use a different sync_bn like `apex`'s version.
-
-        Args:
-            model: pointer to current :class:`LightningModule`.
-
-        Return:
-            LightningModule with batchnorm layers synchronized between process groups
-        """
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)
-
-        return model
-
-    def sync_tensor(self,
-                    tensor: Union[torch.Tensor],
-                    group: Optional[Any] = None,
-                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
-        return sync_ddp_if_available(tensor, group, reduce_op)
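
Taken together with the surviving context lines, this leaves `DDPCPUHPCAccelerator` as a thin subclass of `DDPHPCAccelerator` that only overrides the CPU-specific pieces; everything else (`ddp_train`, `configure_ddp`, `sync_tensor`, and so on) is now inherited from the unified HPC accelerator. A sketch of the module after this commit, assembled from the unchanged and added lines of the diff above (the license header and the hydra availability guard are elided):

# Sketch assembled from the diff above; license header and hydra try/except elided.
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator


class DDPCPUHPCAccelerator(DDPHPCAccelerator):

    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
        super().__init__(trainer, cluster_environment, ddp_plugin)
        self.nickname = 'ddp_cpu'

    def model_to_device(self, model, process_idx):
        # CPU variant: keep the model on the CPU instead of moving it to a GPU
        model.cpu()

    def get_device_ids(self):
        # DDP over CPU processes has no GPU device ids to hand to DistributedDataParallel
        device_ids = None
        return device_ids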

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def ddp_train(self, process_idx, model):
         # on world_size=0 let everyone know training is starting
         if self.trainer.is_global_zero and not torch.distributed.is_initialized():
             log.info('-' * 100)
-            log.info(f'distributed_backend={self.trainer.distributed_backend} (on SLURM)')
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
             log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
             log.info('-' * 100)
 
