Commit 072c272

SeanNaren committed
Rely on ddp plugin for blocking sync behaviour, and skip if we're using manual optimization
1 parent 90b87dd commit 072c272

File tree

3 files changed: +49 -4 lines

pytorch_lightning/accelerators/accelerator.py
pytorch_lightning/plugins/ddp_plugin.py
pytorch_lightning/trainer/training_loop.py
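To make the diffs below easier to follow, here is a minimal sketch, independent of Lightning, of the mechanism the commit is organising: DistributedDataParallel's no_sync() context manager skips the gradient all-reduce on accumulation steps so communication only happens on the step that actually runs the optimizer. The function and argument names (run_accumulation_loop, loss_fn, accumulate_grad_batches) are illustrative, not taken from this commit.

# Minimal sketch (not Lightning code): skip DDP gradient all-reduce on
# accumulation steps, sync only on the step that runs the optimizer.
from contextlib import nullcontext

import torch
from torch.nn.parallel import DistributedDataParallel


def run_accumulation_loop(ddp_model: DistributedDataParallel,
                          loss_fn,
                          optimizer: torch.optim.Optimizer,
                          batches,
                          accumulate_grad_batches: int = 4):
    for i, (x, y) in enumerate(batches):
        last_step = (i + 1) % accumulate_grad_batches == 0
        # no_sync() disables the all-reduce hooks for this backward pass;
        # on the last accumulation step we use a no-op context instead,
        # so gradients are reduced exactly once per optimizer step.
        ctx = nullcontext() if last_step else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model(x), y) / accumulate_grad_batches
            loss.backward()
        if last_step:
            optimizer.step()
            optimizer.zero_grad()

The change below moves the decision of when and how to enter such a context out of the training loop and into the accelerator and its DDP plugin, and skips the blocking entirely under manual optimization.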

pytorch_lightning/accelerators/accelerator.py

Lines changed: 13 additions & 0 deletions

@@ -14,6 +14,7 @@
 
 from enum import Enum
 from typing import Any, Optional, Union
+from contextlib import contextmanager
 
 import torch
 from torch.optim import Optimizer
@@ -244,6 +245,18 @@ def __setstate__(self, d):
     def on_save(self, checkpoint):
         return checkpoint
 
+    @contextmanager
+    def block_ddp_plugin_sync_behaviour(self):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if self.ddp_plugin:
+            yield self.ddp_plugin.block_backward_sync(self.trainer.model)
+        else:
+            yield
+
 
 # TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
 class BackendType(Enum):
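The new accelerator hook is a small "delegate to the plugin, or fall back to a no-op" context manager. Below is a self-contained toy version of that shape with made-up DemoPlugin/DemoAccelerator names rather than the real classes; unlike the diff above, which yields the plugin's context-manager object for the caller to use, this toy enters it directly so the sketch runs on its own.

# Toy illustration of the delegation pattern, not the accelerator's real code.
from contextlib import contextmanager


class DemoPlugin:
    """Stand-in for a DDP plugin that exposes a sync-blocking context manager."""

    @contextmanager
    def block_backward_sync(self, model):
        # A real plugin would wrap model.no_sync() here; the prints just make
        # the control flow visible in this sketch.
        print("gradient sync blocked")
        yield
        print("gradient sync restored")


class DemoAccelerator:
    """Stand-in for an accelerator that may or may not own a DDP plugin."""

    def __init__(self, ddp_plugin=None, model=None):
        self.ddp_plugin = ddp_plugin
        self.model = model

    @contextmanager
    def block_ddp_plugin_sync_behaviour(self):
        if self.ddp_plugin is not None:
            # delegate to the plugin's own context manager
            with self.ddp_plugin.block_backward_sync(self.model):
                yield
        else:
            # no plugin configured: behave as a plain no-op context
            yield


# The call site looks identical with or without a plugin.
with DemoAccelerator(DemoPlugin()).block_ddp_plugin_sync_behaviour():
    pass  # backward pass would go here
with DemoAccelerator().block_ddp_plugin_sync_behaviour():
    pass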

pytorch_lightning/plugins/ddp_plugin.py

Lines changed: 10 additions & 0 deletions

@@ -1,4 +1,5 @@
 import os
+from contextlib import contextmanager
 from typing import Any, Dict, List, Union, Optional
 
 import torch.distributed as torch_distrib
@@ -131,3 +132,12 @@ def get_model_from_plugin(
         if isinstance(model, LightningDistributedDataParallel):
             return model.module
         return model
+
+    @contextmanager
+    def block_backward_sync(self, model: LightningDistributedDataParallel):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        yield model.no_sync()
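Putting the no_sync() call behind DDPPlugin.block_backward_sync means the behaviour can be overridden per plugin instead of being hard-coded in the training loop. A hedged sketch of such an override, assuming DDPPlugin can be subclassed and handed to the Trainer through its plugins argument as in this era of the codebase; the subclass name is made up and the import path is an assumption about this snapshot of the code.

from contextlib import contextmanager

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class AlwaysSyncDDPPlugin(DDPPlugin):
    """Illustrative plugin that never blocks the backward-pass sync, e.g. for a
    wrapped model that does not implement no_sync()."""

    @contextmanager
    def block_backward_sync(self, model):
        # Plain no-op context: gradients are all-reduced on every backward pass.
        yield

A Trainer configured with plugins=[AlwaysSyncDDPPlugin()] would then pay the all-reduce cost on every accumulation step, trading communication for simplicity.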

pytorch_lightning/trainer/training_loop.py

Lines changed: 26 additions & 4 deletions

@@ -668,8 +668,25 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                 # -------------------
 
                 # perform dpp sync only when performing optimizer_step
-                with self.block_ddp_sync_behaviour():
-                    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+                if self.automatic_optimization:
+                    with self.block_ddp_sync_behaviour():
+                        self.training_step_and_backward(
+                            split_batch,
+                            batch_idx,
+                            opt_idx,
+                            optimizer,
+                            self.trainer.hiddens
+                        )
+                else:
+                    # do not block ddp gradient sync when using manual optimization
+                    # as gradients are needed within the training step
+                    self.training_step_and_backward(
+                        split_batch,
+                        batch_idx,
+                        opt_idx,
+                        optimizer,
+                        self.trainer.hiddens
+                    )
 
                 batch_outputs = self._process_closure_result(
                     batch_outputs=batch_outputs,
@@ -735,8 +752,13 @@ def train_step_and_backward_closure():
 
     @contextmanager
     def block_ddp_sync_behaviour(self):
-        if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
-            yield self.trainer.model.no_sync()
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if self.trainer.accelerator_backend is not None:
+            yield self.trainer.accelerator_backend.block_ddp_plugin_sync_behaviour()
         else:
             yield
 
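The training-loop branch keeps gradient sync on for manual optimization because, in that mode, the optimizer step happens inside the training step itself. A framework-free sketch of why that matters; manual_training_step, loss_fn and the batch layout are illustrative names, not Lightning API.

def manual_training_step(ddp_model, loss_fn, optimizer, batch):
    # Sketch of a manually optimized step: backward and optimizer.step() happen
    # together, so gradients must already be all-reduced across ranks here.
    x, y = batch
    loss = loss_fn(ddp_model(x), y)
    # If this backward ran under ddp_model.no_sync(), each rank would step on
    # its local, un-reduced gradients and the replicas would drift apart.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach()

Under automatic optimization, by contrast, the loop knows the optimizer step only fires on the final accumulation batch, so it can safely wrap the intermediate backward passes in block_ddp_sync_behaviour().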
