diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 2c09163413981..a448f306b676a 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -240,16 +240,10 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict return step_kwargs def _track_output_for_epoch_end( - self, - outputs: List[Union[ResultCollection, Dict, Tensor]], - output: Optional[Union[ResultCollection, Dict, Tensor]], - ) -> List[Union[ResultCollection, Dict, Tensor]]: + self, outputs: List[STEP_OUTPUT], output: Optional[STEP_OUTPUT] + ) -> List[STEP_OUTPUT]: if output is not None: - if isinstance(output, ResultCollection): - output = output.detach() - if self.trainer.move_metrics_to_cpu: - output = output.cpu() - elif isinstance(output, dict): + if isinstance(output, dict): output = recursive_detach(output, to_cpu=self.trainer.move_metrics_to_cpu) elif isinstance(output, Tensor) and output.is_cuda and self.trainer.move_metrics_to_cpu: output = output.cpu()