
Commit 4f12334

carmocca authored and awaelchli committed
Fix reference issues during epoch end result collection (#8621)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 3217863 commit 4f12334

File tree

4 files changed: +36 -12 lines changed

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -750,7 +750,7 @@ def training_step(self, batch, batch_idx):
 
             out = self(x)
 
-            # softmax uses only a portion of the batch in the denomintaor
+            # softmax uses only a portion of the batch in the denominator
             loss = self.softmax(out)
             loss = nce_loss(loss)
             return loss
```

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 5 additions & 4 deletions

```diff
@@ -339,15 +339,16 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
 
         loss = None
         hiddens = None
-        results.extra = {}
 
         # handle dict return
         if isinstance(training_step_output, dict):
-            loss = training_step_output.pop("loss", None)
-            hiddens = training_step_output.pop("hiddens", None)
+            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
+            loss = training_step_output.get("loss")
+            hiddens = training_step_output.get("hiddens")
             # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
             hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
-            results.extra = training_step_output
+            # use the setter instead of `dict.update` because it calls `detach` on the tensor items
+            results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
 
         # handle scalar return
         elif isinstance(training_step_output, Tensor):
```
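The heart of this change is avoiding mutation of a dict the user may still hold a reference to after `training_step_end`. A minimal standalone sketch of the failure mode the old `.pop()` calls caused; the function names here are illustrative stand-ins, not Lightning's actual API:

```python
# Sketch of the reference bug fixed above (illustrative names, not Lightning's API).

def process_with_pop(training_step_output: dict) -> dict:
    # Old behaviour: `.pop()` removes keys from the dict the user returned
    # from `training_step`, mutating it behind their back.
    loss = training_step_output.pop("loss", None)
    training_step_output.pop("hiddens", None)
    return {"loss": loss, "extra": training_step_output}


def process_with_get(training_step_output: dict) -> dict:
    # New behaviour: read-only access plus a filtered copy, so the caller's
    # dict is left exactly as they returned it.
    loss = training_step_output.get("loss")
    extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
    return {"loss": loss, "extra": extra}


user_output = {"loss": 0.5, "log_value": 1.0}
process_with_pop(user_output)
assert "loss" not in user_output  # the user's dict was silently stripped of "loss"

user_output = {"loss": 0.5, "log_value": 1.0}
process_with_get(user_output)
assert user_output == {"loss": 0.5, "log_value": 1.0}  # left untouched
```

Note also the comment in the diff: assigning through the `results.extra` setter (rather than `dict.update` on the existing dict) routes the remaining items through the setter, which detaches any tensor values.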

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 3 additions & 7 deletions

```diff
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, Dict, Iterator, List, Optional, Union
 
 import torch
@@ -314,11 +313,7 @@ def _track_epoch_end_reduce_metrics(
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(batch_end_outputs):
             # with 1 step (no tbptt) don't use a sequence at epoch end
-            if (
-                isinstance(opt_outputs, list)
-                and len(opt_outputs) == 1
-                and not isinstance(opt_outputs[0], ResultCollection)
-            ):
+            if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
                 opt_outputs = opt_outputs[0]
 
             epoch_output[opt_idx].append(opt_outputs)
@@ -376,9 +371,10 @@ def _prepare_outputs(
             batch_outputs = [batch_outputs]
 
             for tbptt_output in batch_outputs:
-                out = tbptt_output.extra
+                out = {}
                 if tbptt_output.minimize is not None:
                     out["loss"] = tbptt_output.minimize.detach()
+                out.update(tbptt_output.extra)
                 processed_tbptt_outputs.append(out)
 
         # if there was only one tbptt step then we can collapse that dimension
```
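The `_prepare_outputs` change fixes the aliasing half of the bug: `out = tbptt_output.extra` made `out` the very dict stored on the result object, so writing `out["loss"]` injected a `"loss"` key into state that persists until `training_epoch_end`. A minimal sketch of the before/after behavior, using a simplified stand-in for `ResultCollection` (illustration only, not the real class):

```python
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class FakeResult:
    """Simplified stand-in for Lightning's ResultCollection (illustration only)."""

    minimize: Optional[torch.Tensor] = None
    extra: dict = field(default_factory=dict)


stored = FakeResult(minimize=torch.tensor(0.5), extra={})

# Before the fix: `out` aliases the stored dict...
out = stored.extra
out["loss"] = stored.minimize.detach()
assert "loss" in stored.extra  # ...so the stored state was mutated by reference

# After the fix: build a fresh dict and copy the extras into it.
stored.extra.clear()
out = {}
if stored.minimize is not None:
    out["loss"] = stored.minimize.detach()
out.update(stored.extra)
assert not stored.extra  # the stored state stays untouched
```

This is exactly what the new test below checks: the `extra` dicts of the stored results must stay empty even after `_prepare_outputs` has run.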

tests/trainer/loops/test_training_loop.py

Lines changed: 27 additions & 0 deletions

```diff
@@ -163,3 +163,30 @@ def training_step_end(self, outputs):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
 
     trainer.fit(model)
+
+
+def test_prepare_outputs(tmpdir):
+    """
+    Test that the `extra` field of the saved `ResultCollection` objects for
+    `training_epoch_end` doesn't get accidentally modified by reference.
+    """
+
+    class TestModel(BoringModel):
+        on_train_batch_end_called = 0
+
+        def on_train_batch_end(self, outputs, *args, **kwargs):
+            epoch_outputs = self.trainer.fit_loop.epoch_loop._epoch_output
+            epoch_outputs = epoch_outputs[0]  # 1 optimizer
+            assert len(epoch_outputs) == self.on_train_batch_end_called
+            # `extra` should be empty for all `ResultCollection` objects
+            assert all(not out.extra for out in epoch_outputs)
+            self.on_train_batch_end_called += 1
+
+        def training_epoch_end(self, outputs) -> None:
+            # override so epoch outputs get stored
+            pass
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+    trainer.fit(model)
+    assert model.on_train_batch_end_called == 2
```
