
Commit 760b83d

Author: SeanNaren
Revert "[bug-fix] DDP and automatic_optimization=False (#4485)"
This reverts commit 10488dc.

1 parent: 224fb58 · commit: 760b83d

File tree

8 files changed (+31, -398 lines)


.gitignore

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@ timit_data/
 .Python
 ide_layouts/
 build/
-_build/
 develop-eggs/
 dist/
 downloads/

docs/source/lightning_module.rst

Lines changed: 0 additions & 6 deletions

@@ -1009,12 +1009,6 @@ manual_backward
 .. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_backward
    :noindex:

-manual_optimizer_step
-~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.manual_optimizer_step
-   :noindex:
-
 on_after_backward
 ~~~~~~~~~~~~~~~~~

docs/source/optimizers.rst

Lines changed: 4 additions & 3 deletions

@@ -36,16 +36,17 @@ to manually manage the optimization process. To do so, do the following:

         # use self.backward which will also handle scaling the loss when using amp
         self.manual_backward(loss_a, opt_g)
-        self.manual_optimizer_step(opt_g)
-
+        opt_g.step()
+        opt_g.zero_grad()

         # do anything you want
         loss_b = ...

         # pass in any args that loss.backward() normally takes
         self.manual_backward(loss_b, opt_d, retain_graph=True)
         self.manual_backward(loss_b, opt_d)
-        self.manual_optimizer_step(opt_d)
+        opt_d.step()
+        opt_d.zero_grad()


         # log losses
         self.log('loss_a', loss_a)
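For context, this is what a complete `training_step` looks like against the reverted API, where the user steps and zeros each optimizer directly. A minimal sketch only: the module, losses, and optimizer choices are hypothetical placeholders, not part of this commit.

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class ManualOptExample(pl.LightningModule):
    """Hypothetical two-optimizer module using manual optimization."""

    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(4, 4)
        self.disc = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        (opt_g, opt_d) = self.optimizers()

        # manual_backward still handles AMP loss scaling; stepping and
        # zeroing gradients are now the user's responsibility
        loss_a = self.gen(batch).pow(2).mean()
        self.manual_backward(loss_a, opt_g)
        opt_g.step()
        opt_g.zero_grad()

        loss_b = self.disc(batch).pow(2).mean()
        self.manual_backward(loss_b, opt_d)
        opt_d.step()
        opt_d.zero_grad()

        self.log('loss_a', loss_a)
        self.log('loss_b', loss_b)

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.01),
            torch.optim.SGD(self.disc.parameters(), lr=0.01),
        )
```

On this release line, manual optimization is enabled via `Trainer(automatic_optimization=False)`.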

pytorch_lightning/accelerators/accelerator.py

Lines changed: 4 additions & 5 deletions

@@ -109,11 +109,10 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
     def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
         model_ref = self.trainer.get_model()
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
-        using_native_amp = self.trainer.amp_backend == AMPType.NATIVE
-        automatic_optimization = self.trainer.train_loop.automatic_optimization
+        native_amp = self.trainer.amp_backend == AMPType.NATIVE

         # native amp + lbfgs is a no go right now
-        if using_native_amp and is_lbfgs:
+        if native_amp and is_lbfgs:
             raise MisconfigurationException(
                 'native PyTorch amp and lbfgs are not compatible.'
                 ' To request, please file a Github issue in PyTorch and tag @mcarilli')

@@ -126,12 +125,12 @@ def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
             optimizer_idx=opt_idx,
             optimizer_closure=lambda_closure,
             on_tpu=False,  # TPUAccelerator class sets this as True
-            using_native_amp=using_native_amp,
+            using_native_amp=native_amp,
             using_lbfgs=is_lbfgs
         )

         # scale when native amp
-        if automatic_optimization and using_native_amp:
+        if native_amp:
             self.trainer.scaler.update()

     def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
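The behavioral change is in the last hunk: after the revert, `scaler.update()` runs on every native-AMP optimizer step rather than only under automatic optimization. For reference, this is the plain `torch.cuda.amp` pattern that the accelerator's `optimizer_step` coordinates; the model and data below are placeholders, not library code:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
        loss = model(torch.randn(8, 4, device=device)).pow(2).mean()
    scaler.scale(loss).backward()  # scale the loss before backward
    scaler.step(optimizer)         # unscales gradients, then steps
    scaler.update()                # adjust the scale factor once per step
```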

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 49 deletions

@@ -111,7 +111,6 @@ def __init__(self, *args, **kwargs):
         self._datamodule = None
         self._results: Optional[Result] = None
         self._current_fx_name = ''
-        self._running_manual_backward = False

     def optimizers(self):
         opts = self.trainer.optimizers

@@ -1071,65 +1070,19 @@ def manual_backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -

         .. tip:: In manual mode we still automatically clip grads if Trainer(gradient_clip_val=x) is set

-        .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set
-            and you use `model.manual_optimizer_step(optimizer)`
-
         Example::

             def training_step(...):
                 (opt_a, opt_b) = self.optimizers()
                 loss = ...
                 # automatically applies scaling, etc...
                 self.manual_backward(loss, opt_a)
-                self.manual_optimizer_step(opt_a)
         """
         # make sure we're using manual opt
         self._verify_is_manual_optimization('manual_backward')

         # backward
-        self._running_manual_backward = True
         self.trainer.train_loop.backward(loss, optimizer, -1, *args, **kwargs)
-        self._running_manual_backward = False
-
-    def manual_optimizer_step(self, optimizer: Optimizer, force_optimizer_step: bool = False) -> None:
-        """
-        Call this directly from your training_step when doing optimizations manually.
-        By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
-
-        .. tip:: In manual mode we still automatically accumulate grad over batches if Trainer(accumulate_grad_batches=x) is set.
-
-        Args:
-            optimizer: Optimizer used to perform `.step()` call
-
-            force_optimizer_step: Whether to force an optimizer step. Could be useful when having 2 optimizers
-                and one should use accumulated gradients but not the other one.
-                One could put its own logic to force an optimizer step.
-
-        Example::
-
-            def training_step(...):
-                (opt_a, opt_b) = self.optimizers()
-                loss = ...
-                # automatically applies scaling, etc...
-                self.manual_backward(loss, opt_a)
-                # This will force an opt.step() even if accumulate_grad_batches is set.
-                self.manual_optimizer_step(opt_a, force_optimizer_step=True)
-
-        """
-        # make sure we're using manual opt
-        self._verify_is_manual_optimization('manual_optimizer_step')
-
-        if not self.trainer.train_loop.should_accumulate() or force_optimizer_step:
-
-            # mock closure function as the user is responsible to call `manual_backward`
-            def mock_optimizer_closure():
-                return
-
-            self.trainer.train_loop.optimizer_step(optimizer, None, self.trainer.batch_idx, mock_optimizer_closure)
-
-            # update will be called after every optimizer_step call
-            if self.trainer.amp_backend == AMPType.NATIVE:
-                self.trainer.scaler.update()

     def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
         """

@@ -1150,8 +1103,7 @@ def backward(self, loss, optimizer, optimizer_idx):
             loss.backward()

         """
-        if self.trainer.train_loop.automatic_optimization or self._running_manual_backward:
-            loss.backward(*args, **kwargs)
+        loss.backward(*args, **kwargs)

     def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
         """

pytorch_lightning/trainer/training_loop.py

Lines changed: 13 additions & 58 deletions

@@ -303,12 +303,6 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         # when in dev debugging track the losses
         self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

-    def _check_training_step_output(self, training_step_output):
-        if isinstance(training_step_output, torch.Tensor) and not self.automatic_optimization:
-            if training_step_output.grad_fn is None:
-                # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ...
-                raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor")
-
     def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         # give the PL module a result for logging
         model = self.trainer.get_model()

@@ -318,8 +312,6 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         with self.trainer.profiler.profile("model_forward"):
             args = self.build_train_args(split_batch, batch_idx, opt_idx, hiddens)
             training_step_output = self.trainer.accelerator_backend.training_step(args)
-            self._check_training_step_output(training_step_output)
-
         training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

         training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(

@@ -620,9 +612,6 @@ def run_training_epoch(self):
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()

-        # epoch end hook
-        self.run_on_epoch_end_hook(epoch_output)
-
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(
             epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers

@@ -734,8 +723,6 @@ def train_step_and_backward_closure():

                 if self._curr_step_result is None:
                     # user decided to skip optimization
-                    # make sure to zero grad.
-                    self.zero_grad_handler(batch_idx, optimizer, opt_idx)
                     continue

                 batch_outputs = self._process_closure_result(

@@ -748,11 +735,20 @@ def train_step_and_backward_closure():
                 grad_norm_dic = self._cur_grad_norm_dict
                 self._cur_grad_norm_dict = None

-                # hook + clear gradients
-                self.zero_grad_handler(batch_idx, optimizer, opt_idx)
+                # hook
+                self.on_before_zero_grad(optimizer)
+
+                # clear gradients
+                self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)

-                # update running loss + reset accumulated loss
-                self.update_running_loss()
+                accumulated_loss = self.accumulated_loss.mean()
+
+                if accumulated_loss is not None:
+                    # calculate running loss for display
+                    self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
+
+                # reset for next set of accumulated grads
+                self.accumulated_loss.reset()

                 # collapse all metrics into one dict
                 batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

@@ -953,44 +949,3 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             epoch_end_outputs.append(optimizer_idx_outputs)

         return epoch_end_outputs
-
-    def prepare_optimizers(self):
-        # in manual optimization we loop over all optimizers at once
-        optimizers = self.get_optimizers_iterable()
-        if not self.automatic_optimization:
-            optimizers = [optimizers[0]]
-        return optimizers
-
-    def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
-        # set split_idx to trainer for tracking
-        self.trainer.split_idx = split_idx
-
-        # make sure only the gradients of the current optimizer's parameters are calculated
-        # in the training step to prevent dangling gradients in multiple-optimizer setup.
-        if self.automatic_optimization and len(self.trainer.optimizers) > 1:
-            model = self.trainer.get_model()
-            model.toggle_optimizer(optimizer, opt_idx)
-
-        # use to track metrics internally
-        self.trainer.logger_connector.on_train_split_start(split_idx, opt_idx, split_batch)
-
-    def update_running_loss(self):
-        accumulated_loss = self.accumulated_loss.mean()
-
-        if accumulated_loss is not None:
-            # calculate running loss for display
-            self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
-
-        # reset for next set of accumulated grads
-        self.accumulated_loss.reset()
-
-    def zero_grad_handler(self, batch_idx, optimizer, opt_idx):
-        if self.automatic_optimization:
-            # hook
-            self.on_before_zero_grad(optimizer)
-            optimizers = enumerate([optimizer])
-        else:
-            optimizers = self.get_optimizers_iterable()
-
-        for idx, optimizer in optimizers:
-            self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
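The revert also inlines the running-loss bookkeeping that `update_running_loss` had factored out. The multiplication by `accumulate_grad_batches` undoes the per-step loss scaling applied during gradient accumulation, so the progress bar shows the unscaled loss. A toy numeric check, assuming the loop divides each step's loss by the accumulation factor before backward; the values are made up:

```python
accumulate_grad_batches = 4
raw_losses = [2.0, 2.4, 1.6, 2.0]

# what backward sees: each loss divided by the accumulation factor
scaled = [loss / accumulate_grad_batches for loss in raw_losses]

# self.accumulated_loss.mean() over the accumulation window
accumulated_mean = sum(scaled) / len(scaled)

# what self.running_loss receives for display
display_loss = accumulated_mean * accumulate_grad_batches
print(display_loss)  # 2.0 -- the mean of the raw, unscaled losses
```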
