
Commit 2393474

Authored by tchaton (with SeanNaren and Borda)

[hotfix] ddp + manual_optimisation (#4976)

* Rely on ddp plugin for blocking sync behaviour, and skip if we're using manual optimization
* debug
* Revert "debug" (reverts commit ccca6b6)
* Expose manual reduce for automatic optimization
* Add input arguments
* Enable parity test
* clean imports
* Expose hook after to ensure we reset
* Fix naming
* add
* fix test
* resolve on comments
* typo
* Update tests/trainer/optimization/test_manual_optimization.py
* Update tests/trainer/optimization/test_manual_optimization.py
* update on comments
* resolve comments

Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

1 parent 68ba493 · commit 2393474

File tree: 8 files changed, +196 −17 lines

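For context before the per-file diffs: this hotfix is about manual optimization under DDP. Below is a minimal sketch of the pattern involved; the module, data, and Trainer flags are illustrative only and follow the Lightning API of this era as it appears in the diffs (`manual_backward(loss, optimizer)`, a plain `opt.step()`, `Trainer(automatic_optimization=False)`), not any later release.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class ManualOptimModel(LightningModule):
    """Illustrative module driving its own backward/step (manual optimization)."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self(batch[0]).sum()
        # manual_backward routes through Accelerator.backward, which (after this fix)
        # asks the DDP plugin to prepare the reducer so gradients are still all-reduced
        self.manual_backward(loss, opt)
        opt.step()
        opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

# e.g. Trainer(automatic_optimization=False, gpus=2, accelerator="ddp").fit(
#     ManualOptimModel(), DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8))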

benchmarks/test_sharded_parity.py

Lines changed: 4 additions & 4 deletions
@@ -148,7 +148,6 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     )
 
 
-@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(platform.system() == "Windows",
                     reason="Distributed training is not supported on Windows")
@@ -182,16 +181,17 @@ def training_step(self, batch, batch_idx, optimizer_idx):
         loss_1 = self.step(batch)
 
         self.manual_backward(loss_1, opt_a)
-        self.manual_optimizer_step(opt_a)
+        opt_a.step()
 
         # fake discriminator
         loss_2 = self.step(batch[0])
 
         # ensure we forward the correct params to the optimizer
         # without retain_graph we can't do multiple backward passes
         self.manual_backward(loss_2, opt_b, retain_graph=True)
-        self.manual_backward(loss_2, opt_a, retain_graph=True)
-        self.manual_optimizer_step(opt_b)
+        # todo: understand why synchronization breaks there.
+        # self.manual_backward(loss_2, opt_a, retain_graph=True)
+        opt_b.step()
 
         assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 21 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
 from enum import Enum
 from typing import Any, Optional, Union
 
@@ -86,6 +87,12 @@ def process_dataloader(self, dataloader):
         return dataloader
 
     def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
+        automatic_optimization = self.trainer.train_loop.automatic_optimization
+
+        if not automatic_optimization and self.ddp_plugin is not None:
+            # Manually prepare for reduce as user calling backwards manually
+            self.ddp_plugin.on_before_manual_backward(self.trainer.model, closure_loss)
+
         if self.trainer.precision == 16:
             closure_loss = self.trainer.precision_connector.backend.backward(
                 closure_loss, optimizer, opt_idx, *args, **kwargs
@@ -97,6 +104,10 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
 
         # once backward has been applied, release graph
         closure_loss = closure_loss.detach()
+
+        if not automatic_optimization and self.ddp_plugin is not None:
+            # Manually prepare for reduce as user calling backwards manually
+            self.ddp_plugin.on_after_manual_backward(self.trainer.model)
         return closure_loss
 
     def clip_gradients(self, optimizer, clip_val=None):
@@ -211,6 +222,16 @@ def __setstate__(self, d):
     def on_save(self, checkpoint):
         return checkpoint
 
+    @contextmanager
+    def block_ddp_plugin_sync_behaviour(self):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        cm = self.ddp_plugin.block_backward_sync(self.trainer.model) if self.ddp_plugin else None
+        yield cm
+
 
 # TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
 class BackendType(Enum):
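In short, the accelerator change wraps the user's backward call with a pre-hook and a post-hook whenever optimization is manual. A toy restatement of that control flow (the class below is a made-up stand-in, not Lightning source):

class ToyBackwardWrapper:
    """Stand-in illustrating the control flow added to Accelerator.backward."""

    def __init__(self, ddp_plugin=None, automatic_optimization=True):
        self.ddp_plugin = ddp_plugin
        self.automatic_optimization = automatic_optimization

    def backward(self, closure_loss, model):
        manual = not self.automatic_optimization
        if manual and self.ddp_plugin is not None:
            # let the DDP wrapper register its reduction hooks for this loss
            self.ddp_plugin.on_before_manual_backward(model, closure_loss)
        closure_loss.backward()
        closure_loss = closure_loss.detach()
        if manual and self.ddp_plugin is not None:
            # clear the "prepared" flag so the next forward starts clean
            self.ddp_plugin.on_after_manual_backward(model)
        return closure_loss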

pytorch_lightning/overrides/data_parallel.py

Lines changed: 12 additions & 3 deletions
@@ -161,6 +161,7 @@ def parallel_apply(self, replicas, inputs, kwargs):
 
     def forward(self, *inputs, **kwargs):  # pragma: no-cover
         self._sync_params()
+        self.reducer_reset_hooks()
         fx_called: str = ''
 
         if self.device_ids:
@@ -194,6 +195,15 @@ def forward(self, *inputs, **kwargs):  # pragma: no-cover
         else:
             output = self.module.validation_step(*inputs, **kwargs)
 
+        if not self._reducer_prepared_for_backwards:
+            self.reducer_prepare_for_backwards(output)
+
+        if output is None:
+            warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')
+        return output
+
+    def reducer_prepare_for_backwards(self, output):
+        self._reducer_prepared_for_backwards = True
         if torch.is_grad_enabled():
             # We'll return the output object verbatim since it is a freeform
             # object. We need to find any tensors in this object, though,
@@ -205,9 +215,8 @@ def forward(self, *inputs, **kwargs):  # pragma: no-cover
         else:
             self.reducer.prepare_for_backward([])
 
-        if output is None:
-            warn_missing_output(f'{fx_called} returned None. Did you forget to re')
-        return output
+    def reducer_reset_hooks(self):
+        self._reducer_prepared_for_backwards = False
 
 
 def warn_missing_output(fx_called):
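The forward/reducer split above reduces to a guard flag: the reducer is prepared at most once per iteration, and the flag is cleared on the next forward. A toy version of that pattern, with invented names rather than the Lightning override:

class ToyReducerGuard:
    """Made-up illustration of the _reducer_prepared_for_backwards flag."""

    def __init__(self, reducer):
        self.reducer = reducer
        self._prepared = False

    def reset(self):
        # called at the top of forward(), mirroring reducer_reset_hooks()
        self._prepared = False

    def prepare(self, output):
        # called once per iteration: either automatically at the end of forward(),
        # or explicitly by the DDP plugin right before a manual backward
        if self._prepared:
            return
        self._prepared = True
        self.reducer.prepare_for_backward(output)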

pytorch_lightning/plugins/ddp_plugin.py

Lines changed: 17 additions & 1 deletion
@@ -1,5 +1,6 @@
 import os
-from typing import Any, Dict, List, Optional, Union
+from contextlib import contextmanager
+from typing import Any, Dict, List, Union, Optional
 
 import torch.distributed as torch_distrib
 from torch.optim import Optimizer
@@ -132,3 +133,18 @@ def get_model_from_plugin(
         if isinstance(model, LightningDistributedDataParallel):
             return model.module
         return model
+
+    @contextmanager
+    def block_backward_sync(self, model: LightningDistributedDataParallel):
+        """
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+        Returns: context manager with sync behaviour off
+        """
+        yield model.no_sync()
+
+    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
+        model.reducer_prepare_for_backwards(output)
+
+    def on_after_manual_backward(self, model: LightningDistributedDataParallel):
+        model.reducer_reset_hooks()
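`block_backward_sync` hands back `DistributedDataParallel.no_sync()`, the stock PyTorch context manager that skips the gradient all-reduce. A self-contained sketch of that behaviour outside Lightning, assuming a two-process launch such as `torchrun --nproc_per_node=2 no_sync_demo.py` with the gloo backend available:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group("gloo")  # CPU-friendly backend for the demo
    torch.manual_seed(0)
    model = DDP(torch.nn.Linear(32, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    batch_a, batch_b = torch.randn(4, 32), torch.randn(4, 32)

    with model.no_sync():
        model(batch_a).sum().backward()  # gradients accumulate locally, no all-reduce
    model(batch_b).sum().backward()      # this backward triggers the all-reduce
    optimizer.step()
    optimizer.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()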

pytorch_lightning/plugins/sharded_plugin.py

Lines changed: 7 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Any
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.optimizer import is_lightning_optimizer
@@ -94,3 +94,9 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list:
         if amp_backend == AMPType.NATIVE:
             return [ShardedNativeAMPPlugin(trainer=trainer)]
         return []
+
+    def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any):
+        pass
+
+    def on_after_manual_backward(self, model: 'LightningShardedDataParallel'):
+        pass

pytorch_lightning/trainer/training_loop.py

Lines changed: 23 additions & 5 deletions
@@ -679,9 +679,15 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                 # calculate loss (train step + train step end)
                 # -------------------
 
-                # perform dpp sync only when performing optimizer_step
+                # automatic_optimization=True: perform dpp sync only when performing optimizer_step
+                # automatic_optimization=False: don't block synchronization here
                 with self.block_ddp_sync_behaviour():
-                    self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens)
+                    self.training_step_and_backward(
+                        split_batch,
+                        batch_idx,
+                        opt_idx,
+                        optimizer,
+                        self.trainer.hiddens)
 
                 batch_outputs = self._process_closure_result(
                     batch_outputs=batch_outputs,
@@ -743,10 +749,22 @@ def train_step_and_backward_closure():
 
     @contextmanager
    def block_ddp_sync_behaviour(self):
-        if isinstance(self.trainer.model, torch.nn.parallel.DistributedDataParallel):
-            yield self.trainer.model.no_sync()
+        """
+        automatic_optimization = True
+        Blocks ddp sync gradients behaviour on backwards pass.
+        This is useful for skipping sync when accumulating gradients, reducing communication overhead
+
+        automatic_optimization = False
+        do not block ddp gradient sync when using manual optimization
+        as gradients are needed within the training step
+
+        Returns: context manager with sync behaviour off
+
+        """
+        if self.trainer.accelerator_backend is not None and self.automatic_optimization:
+            yield self.trainer.accelerator_backend.block_ddp_plugin_sync_behaviour()
         else:
-            yield
+            yield None
 
     def _process_closure_result(
             self, batch_outputs: list, opt_idx: int
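The reworked `block_ddp_sync_behaviour` is essentially a conditional no-op: sync is only blocked when Lightning itself will call `optimizer_step`. A toy restatement of that decision (names invented, not Lightning source):

from contextlib import contextmanager

@contextmanager
def maybe_block_sync(ddp_model, automatic_optimization: bool):
    """Suppress DDP gradient sync only when the framework drives optimizer_step itself."""
    if automatic_optimization and ddp_model is not None:
        # assumes ddp_model is a torch.nn.parallel.DistributedDataParallel wrapper
        with ddp_model.no_sync():
            yield
    else:
        # manual optimization: the user needs reduced gradients inside training_step,
        # so backward synchronisation is left untouched
        yield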

tests/special_tests.sh

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-export PL_RUNNING_SPECIAL_TESTS=1
 # Running special tests
+export PL_RUNNING_SPECIAL_TESTS=1
 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
+python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 110 additions & 1 deletion
@@ -18,6 +18,7 @@
 
 import pytest
 import torch
+import torch.distributed as torch_distrib
 import torch.nn.functional as F
 
 from pytorch_lightning import Trainer, seed_everything
@@ -862,7 +863,7 @@ def dis_closure():
             self.manual_backward(loss_dis, opt_dis)
 
         # this will accumulate gradients for 2 batches and then call opt_gen.step()
-        opt_gen.step(closure=gen_closure, make_optimizer_step=batch_idx % 2 == 0, optim='sgd')
+        opt_gen.step(closure=gen_closure, make_optimizer_step=(batch_idx % 2 == 0), optim='sgd')
 
         # update discriminator every 4 baches
         # therefore, no gradient accumulation for discriminator
@@ -904,6 +905,114 @@ def configure_optimizers(self):
     mock_adam_step.assert_has_calls(expected_calls)
 
 
+@patch("torch.optim.Adam.step")
+@patch("torch.optim.SGD.step")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
+def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_step, mock_adam_step, tmpdir):
+    """
+    Tests that `step` works with optimizer_closure and different accumulated_gradient frequency
+    """
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class TestModel(BoringModel):
+
+        def loss_ones(self, batch, prediction):
+            # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+            return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
+
+        def loss_zeros(self, batch, prediction):
+            # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+            return torch.nn.functional.mse_loss(prediction, torch.zeros_like(prediction))
+
+        def manual_sync_grad(self) -> bool:
+            torch_distrib.all_reduce(self.layer.weight.grad.data, async_op=False)
+            return True
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+
+            # emulate gans training
+            opt_gen, opt_dis = self.optimizers()
+
+            # Note: Be careful, don't log on the same key in self.log in both closure
+            # as they will be aggregated together on epoch_end
+
+            world_size = torch_distrib.get_world_size(torch_distrib.group.WORLD)
+            assert world_size == 2
+
+            def compute_loss():
+                x = batch[0]
+                x = F.dropout(x, 0.1)
+                predictions = self(x)
+                predictions = F.dropout(predictions, 0.1)
+                loss_ones = self.loss_ones(None, predictions)
+                loss_zeros = self.loss_zeros(None, predictions)
+                return loss_ones, loss_zeros
+
+            def make_manual_backward(loss, opt, retain_graph=False):
+                self.manual_backward(loss, opt, retain_graph=retain_graph)
+                grad_clone = self.layer.weight.grad.clone()
+                assert self.manual_sync_grad()
+                self.layer.weight.grad /= world_size
+                assert torch.equal(self.layer.weight.grad, grad_clone)
+
+            def gen_closure():
+                loss_ones_gen, loss_zeros = compute_loss()
+                make_manual_backward(loss_ones_gen, opt_gen, retain_graph=True)
+                make_manual_backward(loss_ones_gen, opt_gen)
+
+            def dis_closure():
+                loss_ones_gen, loss_zeros = compute_loss()
+                make_manual_backward(loss_ones_gen, opt_dis, retain_graph=True)
+                make_manual_backward(loss_ones_gen, opt_dis)
+
+            # this will accumulate gradients for 2 batches and then call opt_gen.step()
+            opt_gen.step(closure=gen_closure, make_optimizer_step=batch_idx % 2 == 0, optim='sgd')
+
+            # update discriminator every 4 baches
+            # therefore, no gradient accumulation for discriminator
+            if batch_idx % 4 == 0 :
+                # Note: Set make_optimizer_step to True or it will use by default
+                # Trainer(accumulate_grad_batches=x)
+                opt_dis.step(closure=dis_closure, make_optimizer_step=True, optim='adam')
+
+        def training_epoch_end(self, outputs) -> None:
+            # outputs should be an array with an entry per optimizer
+            assert len(outputs) == 2
+
+        def configure_optimizers(self):
+            optimizer_gen = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001)
+            return [optimizer_gen, optimizer_dis]
+
+    seed_everything(42)
+
+    model = TestModel()
+    model.val_dataloader = None
+    model.training_epoch_end = None
+
+    limit_train_batches = 8
+    trainer = Trainer(
+        automatic_optimization=False,
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+        enable_pl_optimizer=True,
+        gpus=2,
+        accelerator="ddp",
+    )
+
+    trainer.fit(model)
+    expected_calls = [call(closure=ANY, optim='sgd')] * 4
+    mock_sgd_step.assert_has_calls(expected_calls)
+
+    expected_calls = [call(closure=ANY, optim='adam')] * 2
+    mock_adam_step.assert_has_calls(expected_calls)
+
+
 def test_step_with_misconfiguraiton_error_when_overriding_optimizer_zero_grad(tmpdir):
     """
     Tests that `optimizer_zero_grad` in manual_optimization triggers a MisconfigurationException
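The test's `manual_sync_grad` helper redoes by hand the gradient averaging that DDP normally performs. Generalised to every parameter, that looks roughly like the sketch below; the helper name is ours and a process group is assumed to be initialised already (e.g. via torchrun):

import torch
import torch.distributed as dist

def manual_sync_grads(model: torch.nn.Module) -> None:
    """Sum each gradient across ranks, then divide by world size to get the mean."""
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM, async_op=False)
            p.grad.data /= world_size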
