2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -77,7 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 -
 
 
--
+- Fixed reference issues during epoch end result collection ([#8621](https://github.com/PyTorchLightning/pytorch-lightning/pull/8621))
 
 
 - Fixed `trainer.fit_loop.split_idx` always returning `None` ([#8601](https://github.com/PyTorchLightning/pytorch-lightning/pull/8601))
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -750,7 +750,7 @@ def training_step(self, batch, batch_idx):
     out = self(x)
-    # softmax uses only a portion of the batch in the denomintaor
+    # softmax uses only a portion of the batch in the denominator
     loss = self.softmax(out)
     loss = nce_loss(loss)
     return loss
9 changes: 5 additions & 4 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -338,15 +338,16 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
 
         loss = None
         hiddens = None
-        results.extra = {}
 
         # handle dict return
         if isinstance(training_step_output, dict):
-            loss = training_step_output.pop("loss", None)
-            hiddens = training_step_output.pop("hiddens", None)
+            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
+            loss = training_step_output.get("loss")
+            hiddens = training_step_output.get("hiddens")
             # detach hiddens to avoid `RuntimeError: Trying to backward through the graph a second time`
             hiddens = apply_to_collection(hiddens, Tensor, lambda t: t.detach())
-            results.extra = training_step_output
+            # use the setter instead of `dict.update` because it calls `detach` on the tensor items
+            results.extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
 
         # handle scalar return
         elif isinstance(training_step_output, Tensor):
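
This hunk is the heart of the fix. `dict.pop` mutates the dictionary the user returned from `training_step`, so code reading that same dict later (for example in `training_step_end`) would find `loss` and `hiddens` silently missing, and assigning the dict itself to `results.extra` stored a live reference rather than a snapshot. A minimal sketch of the two behaviours, with illustrative function names that are not Lightning internals:

```python
# `step_output` stands in for the dict a user returns from `training_step`


def process_mutating(step_output):
    # old behaviour: `.pop` removes keys from the caller's dict,
    # and the very same dict object is stored as `extra`
    loss = step_output.pop("loss", None)
    extra = step_output  # stored by reference
    return loss, extra


def process_non_mutating(step_output):
    # new behaviour: `.get` leaves the dict untouched, and `extra`
    # is a fresh dict holding only the non-special keys
    loss = step_output.get("loss")
    extra = {k: v for k, v in step_output.items() if k not in ("loss", "hiddens")}
    return loss, extra


out = {"loss": 0.5, "log_me": 1.0}
loss, extra = process_mutating(out)
assert "loss" not in out  # the user's dict was mutated
assert extra is out  # and `extra` aliases it

out = {"loss": 0.5, "log_me": 1.0}
loss, extra = process_non_mutating(out)
assert out == {"loss": 0.5, "log_me": 1.0}  # untouched
assert extra == {"log_me": 1.0} and extra is not out
```

A further detail from the diff comment: routing the filtered dict through the `extra` setter rather than `dict.update` means tensor values get detached instead of staying attached to the autograd graph.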
10 changes: 3 additions & 7 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Any, Dict, Iterator, List, Optional, Union
 
 import torch
@@ -276,11 +275,7 @@ def _track_epoch_end_reduce_metrics(
         # track the outputs to reduce at the end of the epoch
         for opt_idx, opt_outputs in enumerate(batch_end_outputs):
             # with 1 step (no tbptt) don't use a sequence at epoch end
-            if (
-                isinstance(opt_outputs, list)
-                and len(opt_outputs) == 1
-                and not isinstance(opt_outputs[0], ResultCollection)
-            ):
+            if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
                 opt_outputs = opt_outputs[0]
 
             epoch_output[opt_idx].append(opt_outputs)
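
As a small illustration of what the simplified condition does in the common single-step (no truncated BPTT) case, with plain dicts standing in for real step outputs:

```python
# a one-element inner list (a single tbptt step per batch) is unwrapped, so
# epoch-end hooks receive flat outputs instead of nested singleton lists
batch_end_outputs = [[{"loss": 0.1}], [{"loss": 0.2}]]  # 2 optimizers, 1 step each
epoch_output = [[], []]
for opt_idx, opt_outputs in enumerate(batch_end_outputs):
    if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
        opt_outputs = opt_outputs[0]  # drop the tbptt dimension
    epoch_output[opt_idx].append(opt_outputs)
assert epoch_output == [[{"loss": 0.1}], [{"loss": 0.2}]]
```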
@@ -320,9 +315,10 @@ def _prepare_outputs(
                 batch_outputs = [batch_outputs]
 
             for tbptt_output in batch_outputs:
-                out = tbptt_output.extra
+                out = {}
                 if tbptt_output.minimize is not None:
                     out["loss"] = tbptt_output.minimize.detach()
+                out.update(tbptt_output.extra)
                 processed_tbptt_outputs.append(out)
 
         # if there was only one tbptt step then we can collapse that dimension
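
`_prepare_outputs` had the same aliasing problem from the other direction: binding `out` to `tbptt_output.extra` and then writing `out["loss"]` injected a `loss` key into the stored result's `extra` dict itself, which is exactly the leak the new test below asserts against. A standalone sketch with a simplified stand-in for `ResultCollection`:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeResult:
    # simplified stand-in for `ResultCollection`: `minimize` is the loss,
    # `extra` holds whatever else `training_step` returned
    minimize: Optional[float] = None
    extra: dict = field(default_factory=dict)


def prepare_buggy(result: FakeResult) -> dict:
    out = result.extra  # aliases the stored dict
    if result.minimize is not None:
        out["loss"] = result.minimize  # mutates `result.extra` in place
    return out


def prepare_fixed(result: FakeResult) -> dict:
    out = {}  # fresh dict on every call
    if result.minimize is not None:
        out["loss"] = result.minimize
    out.update(result.extra)  # copy the keys, leave `extra` untouched
    return out


r = FakeResult(minimize=0.5, extra={"log_me": 1.0})
prepare_buggy(r)
assert "loss" in r.extra  # the stored state was corrupted

r = FakeResult(minimize=0.5, extra={"log_me": 1.0})
out = prepare_fixed(r)
assert r.extra == {"log_me": 1.0}  # untouched
assert out == {"loss": 0.5, "log_me": 1.0}
```

Building a fresh dict and updating it from `extra` keeps the stored object read-only from this code path.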
27 changes: 27 additions & 0 deletions tests/trainer/loops/test_training_loop.py
@@ -163,3 +163,30 @@ def training_step_end(self, outputs):
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
 
     trainer.fit(model)
+
+
+def test_prepare_outputs(tmpdir):
+    """
+    Test that the `extra` field of the saved `ResultCollection` objects for
+    `training_epoch_end` doesn't get accidentally modified by reference.
+    """
+
+    class TestModel(BoringModel):
+        on_train_batch_end_called = 0
+
+        def on_train_batch_end(self, outputs, *args, **kwargs):
+            epoch_outputs = self.trainer.fit_loop.epoch_loop._epoch_output
+            epoch_outputs = epoch_outputs[0]  # 1 optimizer
+            assert len(epoch_outputs) == self.on_train_batch_end_called
+            # `extra` should be empty for all `ResultCollection` objects
+            assert all(not out.extra for out in epoch_outputs)
+            self.on_train_batch_end_called += 1
+
+        def training_epoch_end(self, outputs) -> None:
+            # override so epoch outputs get stored
+            pass
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+    trainer.fit(model)
+    assert model.on_train_batch_end_called == 2
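
Two details make this test effective: overriding `training_epoch_end` forces the trainer to keep the per-batch `ResultCollection` objects around (per the comment in the test body), and `fast_dev_run=2` runs two batches, so the second `on_train_batch_end` call can catch any `loss` key leaked by reference into a stored `extra` dict while the first batch's outputs were being prepared. It should be runnable in isolation with something like `pytest tests/trainer/loops/test_training_loop.py -k test_prepare_outputs`.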