@@ -40,9 +40,7 @@ def __post_init__(self) -> None:
         self.extra = self._check_extra_detach_deprecation(self.extra)

     @classmethod
-    def from_training_step_output(
-        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
-    ) -> "ManualResult":
+    def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult":
         extra = {}
         if isinstance(training_step_output, dict):
             extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
@@ -55,9 +53,8 @@ def from_training_step_output(
         )

         if "loss" in extra:
-            # accumulate the loss. If `accumulate_grad_batches == 1`, no effect.
             # we detach manually as it's expected that it will have a `grad_fn`
-            extra["loss"] = extra["loss"].detach().div(normalize)
+            extra["loss"] = extra["loss"].detach()

         return cls(extra=extra)

@@ -118,7 +115,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

         self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

-        result = ManualResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)
+        result = ManualResult.from_training_step_output(training_step_output)

         if self.trainer.move_metrics_to_cpu:
             # hiddens and the training step output are not moved as they are not considered "metrics"
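
For reference, below is a minimal, self-contained sketch of the simplified behavior. It is not the actual Lightning source: the `ManualResultSketch` name and the bare-tensor branch are illustrative assumptions. The point of the change is that the loss returned from `training_step` is now only detached and is no longer divided by `trainer.accumulate_grad_batches`, which under manual optimization leaves any gradient scaling to the user.

from dataclasses import dataclass, field
from typing import Any, Dict, Optional

import torch


@dataclass
class ManualResultSketch:
    # illustrative stand-in for `ManualResult`; the real class has more fields
    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_training_step_output(cls, training_step_output: Optional[Any]) -> "ManualResultSketch":
        extra = {}
        if isinstance(training_step_output, dict):
            # keep everything the user returned except the recurrent state
            extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
        elif isinstance(training_step_output, torch.Tensor):
            # assumption: a bare tensor return is treated as the loss
            extra = {"loss": training_step_output}

        if "loss" in extra:
            # detach manually, as the loss is expected to still carry a `grad_fn`;
            # no `.div(normalize)` anymore, so the stored loss is unscaled
            extra["loss"] = extra["loss"].detach()

        return cls(extra=extra)

A quick check of the new behavior:

loss = (torch.ones(1, requires_grad=True) * 2).sum()
result = ManualResultSketch.from_training_step_output({"loss": loss, "hiddens": None})
assert result.extra["loss"].grad_fn is None  # detached
assert float(result.extra["loss"]) == 2.0    # unscaled

Before this change, with `normalize=trainer.accumulate_grad_batches`, the stored loss would instead have been `2.0 / accumulate_grad_batches`.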