
Commit d006c5b

tchaton authored and SeanNaren committed
[bug-fix] DDP and automatic_optimization=False (#4485)
[bug-fix] DDP and automatic_optimization=False (#4485)

* resolve bug
* add self._running_manual_optim
* update
* update tests
* update lightning module
* resolve bug
* update tests
* update
* resolve pep8
* update
* replace by `ddp_spawn`
* temporary fix
* update
* update
* move update to training_loop
* make both ddp_spawn
* introduce `manual_optimizer_step`
* update changelog
* added changelog wrong place
* add force_optimizer_step
* update docstring for tests
* update optimizer_step
* update zero_grad
* resolve flake8
* move update into manual_optimizer_step
* add zero_grad
* remove zero_grad tests
* remove manual_backward in AMP, it doesn't help
* update
* loosen tests
* update
* update doc
* add TODO
* Removed unnecessary get model from native amp
* Remove try except with pytest raise
* Add seed, clean up imports, remove try catch to reproduce error
* update code
* update test
* revert back
* formatting
* Update pytorch_lightning/core/lightning.py

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: SeanNaren <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit 7e08b0d)
1 parent: c30ba1b · commit: d006c5b

File tree

9 files changed: +402 −35 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ timit_data/
 .Python
 ide_layouts/
 build/
+_build/
 develop-eggs/
 dist/
 downloads/

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


+- Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
+
+
 - Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))

docs/source/lightning_module.rst

Lines changed: 6 additions & 0 deletions
@@ -1009,6 +1009,12 @@ manual_backward
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
    :noindex:

+manual_optimizer_step
+~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
+   :noindex:
+
 on_after_backward
 ~~~~~~~~~~~~~~~~~

docs/source/optimizers.rst

Lines changed: 3 additions & 4 deletions
@@ -36,17 +36,16 @@ to manually manage the optimization process. To do so, do the following:

        # use self.backward which will also handle scaling the loss when using amp
        self.manual_backward(loss_a, opt_g)
-       opt_g.step()
-       opt_g.zero_grad()
+       self.manual_optimizer_step(opt_g)
+

        # do anything you want
        loss_b = ...

        # pass in any args that loss.backward() normally takes
        self.manual_backward(loss_b, opt_d, retain_graph=True)
        self.manual_backward(loss_b, opt_d)
-       opt_d.step()
-       opt_d.zero_grad()
+       self.manual_optimizer_step(opt_d)

        # log losses
        self.log('loss_a', loss_a)

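For readers landing on this commit, here is a minimal end-to-end sketch of the pattern the updated docs describe. The module, data, and losses are illustrative, and enabling manual optimization via `Trainer(automatic_optimization=False)` is an assumption about the flag available in this release line; only the `manual_backward` / `manual_optimizer_step` calls mirror the diff above.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TwoOptimizerModule(pl.LightningModule):
    """Toy two-optimizer module using the manual-optimization API touched by this commit."""

    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(8, 8)
        self.disc = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        (opt_g, opt_d) = self.optimizers()
        x, = batch

        # first optimizer: manual_backward handles AMP scaling,
        # manual_optimizer_step handles gradient accumulation
        loss_g = self.gen(x).pow(2).mean()
        self.manual_backward(loss_g, opt_g)
        self.manual_optimizer_step(opt_g)

        # second optimizer, same pattern
        loss_d = self.disc(x).pow(2).mean()
        self.manual_backward(loss_d, opt_d)
        self.manual_optimizer_step(opt_d)

        # log instead of returning the loss tensor (see training_loop.py below)
        self.log('loss_g', loss_g)
        self.log('loss_d', loss_d)

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.01),
            torch.optim.SGD(self.disc.parameters(), lr=0.01),
        )


if __name__ == '__main__':
    data = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=8)
    # assumption: manual optimization is enabled through this Trainer flag in this release line
    trainer = pl.Trainer(max_epochs=1, automatic_optimization=False)
    trainer.fit(TwoOptimizerModule(), data)
```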
pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 4 deletions
@@ -109,10 +109,11 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
     def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
         model_ref = self.trainer.get_model()
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
-        native_amp = self.trainer.amp_backend == AMPType.NATIVE
+        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
+        automatic_optimization = self.trainer.train_loop.automatic_optimization

         # native amp + lbfgs is a no go right now
-        if native_amp and is_lbfgs:
+        if using_native_amp and is_lbfgs:
             raise MisconfigurationException(
                 'native PyTorch amp and lbfgs are not compatible.'
                 ' To request, please file a Github issue in PyTorch and tag @mcarilli')
@@ -125,12 +126,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
             optimizer_idx=opt_idx,
             optimizer_closure=lambda_closure,
             on_tpu=False,  # TPUAccelerator class sets this as True
-            using_native_amp=native_amp,
+            using_native_amp=using_native_amp,
             using_lbfgs=is_lbfgs
         )

         # scale when native amp
-        if native_amp:
+        if automatic_optimization and using_native_amp:
             self.trainer.scaler.update()

     def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):

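Background on why the `scaler.update()` call is now gated on automatic optimization: with native AMP, `GradScaler.update()` is expected to run once after each `scaler.step(optimizer)`, and in manual mode that update now happens inside `manual_optimizer_step` (see `core/lightning.py` below), so the accelerator must not update a second time. A plain-PyTorch sketch of that contract (illustrative, assumes a CUDA device; this is not Lightning code):

```python
import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device='cuda')).sum()
    scaler.scale(loss).backward()   # scaled backward, like manual_backward
    scaler.step(optimizer)          # unscales grads, then optimizer.step()
    scaler.update()                 # exactly once per step; a second update would skew the scale
```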
pytorch_lightning/core/lightning.py

Lines changed: 49 additions & 1 deletion
@@ -111,6 +111,7 @@ def __init__(self, *args, **kwargs):
         self._datamodule = None
         self._results: Optional[Result] = None
         self._current_fx_name = ''
+        self._running_manual_backward = False

     def optimizers(self):
         opts = self.trainer.optimizers
@@ -1068,19 +1069,65 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -

        .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set

+       .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
+           and you use `model.manual_optimizer_step(optimizer)`
+
        Example::

            def training_step(...):
                (opt_a, opt_b) = self.optimizers()
                loss = ...
                # automatically applies scaling, etc...
                self.manual_backward(loss, opt_a)
+               self.manual_optimizer_step(opt_a)
        """
        # make sure we're using manual opt
        self._verify_is_manual_optimization('manual_backward')

        # backward
+       self._running_manual_backward = True
        self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
+       self._running_manual_backward = False
+
+   def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step:bool = False) -> None:
+       """
+       Call this directly from your training_step when doing optimizations manually.
+       By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
+
+       .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.
+
+       Args:
+           optimizer: Optimizer used to perform `.step()` call
+
+           force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
+               and one should use accumulated gradients but not the other one.
+               One could put its own logic to force an optimizer step.
+
+       Example::
+
+           def training_step(...):
+               (opt_a, opt_b) = self.optimizers()
+               loss = ...
+               # automatically applies scaling, etc...
+               self.manual_backward(loss, opt_a)
+               # This will force an opt.step() even if accumulate_grad_batches is set.
+               self.manual_optimizer_step(opt_a, force_optimizer_step=True)
+
+       """
+       # make sure we're using manual opt
+       self._verify_is_manual_optimization('manual_optimizer_step')
+
+       if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:
+
+           # mock closure function as the user is responsible to call `manual_backward`
+           def mock_optimizer_closure():
+               return
+
+           self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)
+
+           # update will be called after every optimizer_step call
+           if self.trainer.amp_backend == AMPType.NATIVE:
+               self.trainer.scaler.update()

    def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
        """
@@ -1101,7 +1148,8 @@ def backward(self, loss, optimizer, optimizer_idx):
            loss.backward()

        """
-       loss.backward(*args, **kwargs)
+       if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
+           loss.backward(*args, **kwargs)

    def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
        """

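To make the `force_optimizer_step` semantics above concrete, here is a sketch of a `training_step` where one optimizer respects `Trainer(accumulate_grad_batches=N)` and the other is forced to step on every batch; `self.gen` and `self.disc` are illustrative submodules, not part of this commit:

```python
def training_step(self, batch, batch_idx):
    (opt_g, opt_d) = self.optimizers()
    x, = batch

    # opt_g honours Trainer(accumulate_grad_batches=N): it only steps
    # when the accumulation window closes.
    loss_g = self.gen(x).pow(2).mean()          # `self.gen` is an illustrative submodule
    self.manual_backward(loss_g, opt_g)
    self.manual_optimizer_step(opt_g)

    # opt_d bypasses accumulation and steps on every batch.
    loss_d = self.disc(x).pow(2).mean()         # `self.disc` is an illustrative submodule
    self.manual_backward(loss_d, opt_d)
    self.manual_optimizer_step(opt_d, force_optimizer_step=True)

    self.log('loss_g', loss_g)
    self.log('loss_d', loss_d)
```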
pytorch_lightning/trainer/training_loop.py

Lines changed: 59 additions & 17 deletions
@@ -251,13 +251,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_train_epoch_start")

     def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+        # hook
+        self.trainer.call_hook('on_batch_end')
+        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+
         # figure out what to track for epoch end
         self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)

-        # hook
-        self.trainer.call_hook("on_batch_end")
-        self.trainer.call_hook("on_train_batch_end", epoch_end_outputs, batch, batch_idx, dataloader_idx)
-
     def reset_train_val_dataloaders(self, model):
         if not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
@@ -303,6 +303,12 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
            # when in dev debugging track the losses
            self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

+    def _check_training_step_output(self, training_step_output):
+        if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
+            if training_step_output.grad_fn is None:
+                # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
+                raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
+
     def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         # give the PL module a result for logging
         model = self.trainer.get_model()
@@ -312,6 +318,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
        with self.trainer.profiler.profile("model_forward"):
            args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
            training_step_output = self.trainer.accelerator_backend.training_step(args)
+           self._check_training_step_output(training_step_output)
+
            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

            training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
@@ -724,6 +732,8 @@ def train_step_and_backward_closure():

                    if self._curr_step_result is None:
                        # user decided to skip optimization
+                       # make sure to zero grad.
+                       self.zero_grad_handler(batch_idx, optimizer, opt_idx)
                        continue

                    batch_outputs = self._process_closure_result(
@@ -736,20 +746,11 @@ def train_step_and_backward_closure():
                    grad_norm_dic = self._cur_grad_norm_dict
                    self._cur_grad_norm_dict = None

-                   # hook
-                   self.on_before_zero_grad(optimizer)
+                   # hook + clear gradients
+                   self.zero_grad_handler(batch_idx, optimizer, opt_idx)

-                   # clear gradients
-                   self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
-
-                   accumulated_loss = self.accumulated_loss.mean()
-
-                   if accumulated_loss is not None:
-                       # calculate running loss for display
-                       self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
-
-                   # reset for next set of accumulated grads
-                   self.accumulated_loss.reset()
+                   # update running loss + reset accumulated loss
+                   self.update_running_loss()

                    # collapse all metrics into one dict
                    batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
@@ -950,3 +951,44 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
                epoch_end_outputs.append(optimizer_idx_outputs)

        return epoch_end_outputs
+
+    def prepare_optimizers(self):
+        # in manual optimization we loop over all optimizers at once
+        optimizers = self.get_optimizers_iterable()
+        if not self.automatic_optimization:
+            optimizers = [optimizers[0]]
+        return optimizers
+
+    def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
+        # set split_idx to trainer for tracking
+        self.trainer.split_idx = split_idx
+
+        # make sure only the gradients of the current optimizer's parameters are calculated
+        # in the training step to prevent dangling gradients in multiple-optimizer setup.
+        if self.automatic_optimization and len(self.trainer.optimizers) > 1:
+            model = self.trainer.get_model()
+            model.toggle_optimizer(optimizer, opt_idx)
+
+        # use to track metrics internally
+        self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)
+
+    def update_running_loss(self):
+        accumulated_loss = self.accumulated_loss.mean()
+
+        if accumulated_loss is not None:
+            # calculate running loss for display
+            self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
+
+        # reset for next set of accumulated grads
+        self.accumulated_loss.reset()
+
+    def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
+        if self.automatic_optimization:
+            # hook
+            self.on_before_zero_grad(optimizer)
+            optimizers = enumerate([optimizer])
+        else:
+            optimizers = self.get_optimizers_iterable()
+
+        for idx, optimizer in optimizers:
+            self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

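To illustrate the new `_check_training_step_output` guard: under manual optimization it raises a `MisconfigurationException` when `training_step` returns a Tensor without a `grad_fn` (for example a detached loss), so the idiom is to drive `manual_backward` / `manual_optimizer_step` yourself and return `None` or log instead. A minimal sketch follows; the single-optimizer return shape of `optimizers()` is an assumption, not something this diff guarantees.

```python
import torch
import pytorch_lightning as pl


class ManualStepModule(pl.LightningModule):
    """Illustrates the return-value rule enforced by _check_training_step_output."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # assumption: returns the single configured optimizer
        loss = self.layer(batch[0]).pow(2).mean()
        self.manual_backward(loss, opt)
        self.manual_optimizer_step(opt)

        # return loss.detach()   # a detached Tensor would trigger MisconfigurationException here
        self.log('train_loss', loss)  # log instead; returning None is accepted

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```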