
Commit d78d232

awaelchli authored and kaushikb11 committed
Fix ShardedDataParallel has no attribute require_backward_grad_sync (#6915)
Co-authored-by: Kaushik B <[email protected]>
(cherry picked from commit fe0d088)
1 parent 8245540 · commit d78d232

File tree (4 files changed, +52 -0 lines changed)

- CHANGELOG.md
- pytorch_lightning/plugins/training_type/sharded.py
- pytorch_lightning/plugins/training_type/sharded_spawn.py
- tests/plugins/test_sharded_plugin.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -234,6 +234,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
 
 
+- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
```
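
For context, a minimal sketch (not part of the commit) of the scenario this entry describes: a manually optimized module trained with the sharded plugin. `ManualModel` below is a hypothetical stand-in for the `ManualBoringModel` added in tests/plugins/test_sharded_plugin.py; it assumes fairscale is installed and two GPUs are available.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ManualModel(pl.LightningModule):
    """Manual optimization, mirroring the ManualBoringModel added in this commit."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # manual optimization exercises the fixed code path

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch[0]).sum()
        self.manual_backward(loss)  # previously hit the missing `require_backward_grad_sync` attribute
        opt.step()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(accelerator="ddp_sharded", gpus=2, fast_dev_run=2)
    trainer.fit(ManualModel(), data)
```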

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -1,5 +1,8 @@
 from typing import Optional
 
+import torch
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
@@ -19,6 +22,7 @@ def configure_ddp(self):
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
         )
+        setattr(self._model, "require_backward_grad_sync", False)
 
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
@@ -57,3 +61,9 @@ def _optim_state_dict(self, optimizer):
     @property
     def lightning_module(self) -> LightningModule:
         return unwrap_lightning_module_sharded(self._model)
+
+    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
+        pass
+
+    def post_training_step(self):
+        pass
```
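
Why this works, briefly: plugin code shared with regular DDP consults a `require_backward_grad_sync` flag on the wrapped model during manual optimization. Torch's `DistributedDataParallel` defines that attribute, but fairscale's `ShardedDataParallel` does not, hence the `AttributeError` named in the commit title. Pinning the flag to `False`, and turning `pre_backward`/`post_training_step` into no-ops since fairscale's wrapper manages its own gradient reduction, avoids the failing lookup. A self-contained sketch of just the attribute part, with `DummyShardedModel` as a hypothetical stand-in for the fairscale wrapper:

```python
import torch.nn as nn


class DummyShardedModel(nn.Module):
    """Hypothetical stand-in for fairscale's ShardedDataParallel, which does not
    define `require_backward_grad_sync` (torch's DistributedDataParallel does)."""


model = DummyShardedModel()

try:
    # What the shared DDP code path effectively does during manual optimization.
    model.require_backward_grad_sync
except AttributeError as err:
    print(f"before the fix: {err}")

# The commit's workaround: pin the flag on the wrapper right after it is built,
# so the lookup succeeds and the sync-skipping logic is simply bypassed.
setattr(model, "require_backward_grad_sync", False)
print("after the fix:", model.require_backward_grad_sync)
```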

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -1,5 +1,8 @@
 from typing import Optional
 
+import torch
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
@@ -18,6 +21,7 @@ def configure_ddp(self):
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
         )
+        setattr(self._model, "require_backward_grad_sync", False)
 
     def _reinit_optimizers_with_oss(self):
         optimizers = self.lightning_module.trainer.optimizers
@@ -52,3 +56,9 @@ def _optim_state_dict(self, optimizer):
     @property
     def lightning_module(self) -> LightningModule:
         return unwrap_lightning_module_sharded(self._model)
+
+    def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
+        pass
+
+    def post_training_step(self):
+        pass
```

tests/plugins/test_sharded_plugin.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -308,3 +308,32 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir):
     )
 
     trainer.test(model)
+
+
+class ManualBoringModel(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
+        opt = self.optimizers()
+        opt.zero_grad()
+        output = self(batch)
+        loss = self.loss(batch, output)
+        self.manual_backward(loss)
+        opt.step()
+        return {"loss": loss}
+
+
+@RunIf(skip_windows=True, special=True, fairscale=True, min_gpus=2)
+@pytest.mark.parametrize("accelerator", ["ddp_sharded", "ddp_sharded_spawn"])
+def test_ddp_sharded_plugin_manual_optimization(tmpdir, accelerator):
+    model = ManualBoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        accelerator=accelerator,
+        fast_dev_run=2,
+        gpus=2,
+    )
+    trainer.fit(model)
```
