File tree Expand file tree Collapse file tree 1 file changed +29
-0
lines changed
pytorch_lightning/metrics/functional Expand file tree Collapse file tree 1 file changed +29
-0
lines changed Original file line number Diff line number Diff line change 1+ import torch
2+
3+
4+ def reduce (to_reduce : torch .Tensor , reduction : str ) -> torch .Tensor :
5+ """
6+ reduces a given tensor by a given reduction method
7+ Parameters
8+ ----------
9+ to_reduce : torch.Tensor
10+ the tensor, which shall be reduced
11+ reduction : str
12+ a string specifying the reduction method.
13+ should be one of 'elementwise_mean' | 'none' | 'sum'
14+ Returns
15+ -------
16+ torch.Tensor
17+ reduced Tensor
18+ Raises
19+ ------
20+ ValueError
21+ if an invalid reduction parameter was given
22+ """
23+ if reduction == 'elementwise_mean' :
24+ return torch .mean (to_reduce )
25+ if reduction == 'none' :
26+ return to_reduce
27+ if reduction == 'sum' :
28+ return torch .sum (to_reduce )
29+ raise ValueError ('Reduction parameter unknown.' )
You can’t perform that action at this time.
0 commit comments