Commit 6a0c47a

remove redundant accumulation normalization in manual optimization (#9769)

1 parent f915a8a

File tree

2 files changed: +5, −8 lines


pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 3 additions & 6 deletions

@@ -40,9 +40,7 @@ def __post_init__(self) -> None:
         self.extra = self._check_extra_detach_deprecation(self.extra)

     @classmethod
-    def from_training_step_output(
-        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
-    ) -> "ManualResult":
+    def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult":
         extra = {}
         if isinstance(training_step_output, dict):
             extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
@@ -55,9 +53,8 @@ def from_training_step_output(
             )

         if "loss" in extra:
-            # accumulate the loss. If `accumulate_grad_batches == 1`, no effect.
             # we detach manually as it's expected that it will have a `grad_fn`
-            extra["loss"] = extra["loss"].detach().div(normalize)
+            extra["loss"] = extra["loss"].detach()

         return cls(extra=extra)

@@ -118,7 +115,7 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]

         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

-        result = ManualResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
+        result = ManualResult.from_training_step_output(training_step_output)

         if self.trainer.move_metrics_to_cpu:
             # hiddens and the training step output are not moved as they are not considered "metrics"
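
The removed `normalize` argument pre-divided the logged loss by `accumulate_grad_batches`, which is redundant in manual optimization: the stored loss is only detached for logging, while any gradient scaling is the user's responsibility. A minimal sketch of what the updated classmethod now does, extracted here as a standalone function for illustration (the function name is invented; the real logic lives on `ManualResult`):

    from typing import Any, Dict, Optional

    import torch

    def extract_manual_result_extra(training_step_output: Optional[Any]) -> Dict[str, Any]:
        """Sketch of the updated `from_training_step_output` body."""
        extra: Dict[str, Any] = {}
        if isinstance(training_step_output, dict):
            # keep every key except "hiddens"
            extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
        if "loss" in extra:
            # detach only; no `.div(normalize)` anymore -- the stored loss is a
            # logging value, so scaling it by `accumulate_grad_batches` was redundant
            extra["loss"] = extra["loss"].detach()
        return extra

    # mirrors the updated unit test below
    out = {"loss": torch.tensor(25.0, requires_grad=True), "something": "jiraffe"}
    assert extract_manual_result_extra(out)["loss"] == 25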

tests/loops/optimization/test_manual_loop.py

Lines changed: 2 additions & 2 deletions

@@ -22,10 +22,10 @@

 def test_manual_result():
     training_step_output = {"loss": torch.tensor(25.0, requires_grad=True), "something": "jiraffe"}
-    result = ManualResult.from_training_step_output(training_step_output, normalize=5)
+    result = ManualResult.from_training_step_output(training_step_output)
     asdict = result.asdict()
     assert not asdict["loss"].requires_grad
-    assert asdict["loss"] == 5
+    assert asdict["loss"] == 25
     assert result.extra == asdict


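For context, under manual optimization the user drives `backward` and optimizer stepping, including any gradient accumulation, so loss scaling belongs in user code rather than in `ManualResult`. A hypothetical module illustrating that division of responsibility (the model, data shapes, and accumulation factor are invented for this sketch):

    import torch
    from pytorch_lightning import LightningModule

    class ManualAccumulationModule(LightningModule):
        """Sketch of user-side gradient accumulation with manual optimization."""

        def __init__(self, accumulate: int = 4):
            super().__init__()
            self.automatic_optimization = False  # opt into manual optimization
            self.accumulate = accumulate
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            # the user chooses the scaling before backward; Lightning only
            # detaches the returned loss for logging
            self.manual_backward(loss / self.accumulate)
            if (batch_idx + 1) % self.accumulate == 0:
                opt.step()
                opt.zero_grad()
            return {"loss": loss}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)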