Skip to content

Commit 7f48c87

Browse files
justusschockjustusschockteddykokerananthsubtchaton
authored andcommitted
Skips DDP parameter sync (#4301)
* ddp no-sync * Update pytorch_lightning/trainer/training_loop.py Co-authored-by: ananthsub <[email protected]> * Update training_loop.py * factor __enter__ and __exit__ out to separate context manager * delete _updated_model_last_step Co-authored-by: justusschock <[email protected]> Co-authored-by: Teddy Koker <[email protected]> Co-authored-by: ananthsub <[email protected]> Co-authored-by: chaton <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> (cherry picked from commit bbd81df)
1 parent d3a818b commit 7f48c87

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

pytorch_lightning/trainer/training_loop.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import subprocess
15+
from contextlib import contextmanager
1616
from copy import copy, deepcopy
1717

1818
import numpy as np
@@ -656,6 +656,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
656656
# checks if backward or backward + optimizer step (via closure)
657657
accumulation_done = self._accumulated_batches_reached()
658658
is_final_batch = self._num_training_batches_reached()
659+
should_accumulate = not (accumulation_done or is_final_batch)
659660

660661
# lightning module hook
661662
splits = self.tbptt_split_batch(batch)
@@ -676,13 +677,17 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
676677
model = self.trainer.get_model()
677678
model.toggle_optimizer(optimizer, opt_idx)
678679

679-
if not (accumulation_done or is_final_batch):
680+
if should_accumulate:
680681
# For gradient accumulation
681682

682683
# -------------------
683684
# calculate loss (train step + train step end)
684685
# -------------------
685-
self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
686+
687+
# perform dpp sync only when performing optimizer_step
688+
with self.block_ddp_sync_behaviour():
689+
self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
690+
686691
batch_outputs = self._process_closure_result(
687692
batch_callback_metrics=batch_callback_metrics,
688693
batch_log_metrics=batch_log_metrics,
@@ -696,7 +701,6 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
696701
# gradient update with accumulated gradients
697702

698703
else:
699-
700704
if self.automatic_optimization:
701705

702706
def train_step_and_backward_closure():
@@ -761,6 +765,13 @@ def train_step_and_backward_closure():
761765
)
762766
return result
763767

768+
@contextmanager
769+
def block_ddp_sync_behaviour(self):
770+
if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
771+
yield from self.trainer.model.no_sync()
772+
else:
773+
yield
774+
764775
def _process_closure_result(
765776
self, batch_callback_metrics: list, batch_log_metrics: list, batch_outputs: list, opt_idx: int
766777
) -> list:

0 commit comments

Comments
 (0)