Commit cc6284a

Move block_backward_sync from ParallelPlugin to DDPPlugins

Parent: de57fef

5 files changed: +36 additions, -18 deletions
CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -205,6 +205,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `InterBatchProcessor` in favor of `DataLoaderIterDataFetcher` ([#9052](https://github.com/PyTorchLightning/pytorch-lightning/pull/9052))


+- Removed `block_backward_sync` from `ParallelPlugin` and added to `DDPPlugin` and `DDPSpawnPlugin` ([#9101](https://github.com/PyTorchLightning/pytorch-lightning/pull/9101))
+
+
 ### Fixed

 - Fixed save/load/resume from checkpoint for DeepSpeed Plugin (

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 5 additions & 4 deletions
@@ -31,7 +31,7 @@
     _process_training_step_output,
     check_finite_loss,
 )
-from pytorch_lightning.plugins import ParallelPlugin
+from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
 from pytorch_lightning.trainer.progress import OptimizationProgress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
@@ -430,9 +430,10 @@ def block_ddp_sync_behaviour(self, should_block_sync: bool = False) -> Generator
         Returns:
             context manager with sync behaviour off
         """
-        if isinstance(self.trainer.training_type_plugin, ParallelPlugin) and (
-            self.trainer.lightning_module.automatic_optimization or should_block_sync
-        ):
+        if (
+            isinstance(self.trainer.training_type_plugin, DDPPlugin)
+            or isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
+        ) and (self.trainer.lightning_module.automatic_optimization or should_block_sync):
             with self.trainer.training_type_plugin.block_backward_sync():
                 yield None
             else:
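
To make the intent of `block_ddp_sync_behaviour` concrete, here is a minimal sketch of the same pattern in plain PyTorch, outside of Lightning. It is an illustration only: `model`, `optimizer` and `batches` are hypothetical stand-ins for what the trainer manages, and the logic simply skips DDP's gradient all-reduce on every micro-batch except the last one of an accumulation window.

from contextlib import nullcontext
from typing import Iterable

import torch
from torch.nn.parallel import DistributedDataParallel


def accumulate_gradients(
    model: DistributedDataParallel,
    optimizer: torch.optim.Optimizer,
    batches: Iterable[torch.Tensor],
    accumulate_grad_batches: int = 4,
) -> None:
    # Sketch of the behaviour the loop above enables: block gradient sync on
    # intermediate micro-batches, let DDP all-reduce on the last one.
    for i, batch in enumerate(batches):
        is_last = (i + 1) % accumulate_grad_batches == 0
        sync_ctx = nullcontext() if is_last else model.no_sync()
        with sync_ctx:
            loss = model(batch).mean()
            loss.backward()
        if is_last:
            optimizer.step()
            optimizer.zero_grad()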

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 14 additions & 0 deletions
@@ -19,6 +19,7 @@
 import sys
 import tempfile
 import time
+from contextlib import contextmanager
 from pathlib import Path
 from time import sleep
 from typing import Any, Dict, List, Optional, Union
@@ -442,3 +443,16 @@ def reconciliate_processes(self, trace: str):
             os.kill(pid, signal.SIGKILL)
         shutil.rmtree(sync_dir)
         raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
+
+    @contextmanager
+    def block_backward_sync(self):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, DistributedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None
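
Both new plugin methods are thin wrappers around `torch.nn.parallel.DistributedDataParallel.no_sync()`, falling back to a plain `yield` when the model is not DDP-wrapped. A hedged usage sketch follows; the function name and its `plugin`, `module` and `batch` arguments are hypothetical placeholders, not part of this commit.

from typing import Any

import torch
from pytorch_lightning.plugins import DDPPlugin


def backward_without_sync(plugin: DDPPlugin, module: torch.nn.Module, batch: Any) -> torch.Tensor:
    # Run one backward pass with DDP gradient synchronisation blocked.
    # If plugin.model is not a DistributedDataParallel instance, the context
    # manager yields without doing anything and backward behaves as usual.
    with plugin.block_backward_sync():
        loss = module(batch).mean()
        loss.backward()
    return loss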

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 14 additions & 0 deletions
@@ -14,6 +14,7 @@
 import logging
 import os
 import re
+from contextlib import contextmanager
 from multiprocessing.queues import SimpleQueue
 from typing import Any, Dict, List, Optional, Union

@@ -364,3 +365,16 @@ def register_plugins(cls, plugin_registry: Dict) -> None:
             description="DDPSpawn Plugin with `find_unused_parameters` as False",
             find_unused_parameters=False,
         )
+
+    @contextmanager
+    def block_backward_sync(self):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        if isinstance(self.model, DistributedDataParallel):
+            with self.model.no_sync():
+                yield None
+        else:
+            yield None

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 0 additions & 14 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
 from typing import Any, List, Optional

 import torch
@@ -121,19 +120,6 @@ def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule
         """
         return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

-    @contextmanager
-    def block_backward_sync(self):
-        """
-        Blocks ddp sync gradients behaviour on backwards pass.
-        This is useful for skipping sync when accumulating gradients, reducing communication overhead
-        Returns: context manager with sync behaviour off
-        """
-        if isinstance(self.model, DistributedDataParallel):
-            with self.model.no_sync():
-                yield None
-        else:
-            yield None
-
     def teardown(self) -> None:
         # Un-reference the wrapper if any was used.
         # todo (tchaton): Add support for all plugins.
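
A side effect of this removal is that third-party plugins deriving from `ParallelPlugin` no longer inherit `block_backward_sync` and must define it themselves if they wrap the model in `DistributedDataParallel`. A hedged, partial sketch of what that could look like (the class name `MyParallelPlugin` is hypothetical, and the remaining abstract methods are omitted):

from contextlib import contextmanager

from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.plugins import ParallelPlugin


class MyParallelPlugin(ParallelPlugin):
    # Partial sketch: re-adds the context manager that previously lived on
    # ParallelPlugin itself; other required plugin methods are not shown.
    @contextmanager
    def block_backward_sync(self):
        if isinstance(self.model, DistributedDataParallel):
            with self.model.no_sync():
                yield None
        else:
            yield None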
