Skip to content

Commit b193059

Browse files
committed
add function to reduce tensors (similar to reduction in torch.nn)
1 parent 79f0731 commit b193059

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
3+
4+
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
5+
"""
6+
reduces a given tensor by a given reduction method
7+
Parameters
8+
----------
9+
to_reduce : torch.Tensor
10+
the tensor, which shall be reduced
11+
reduction : str
12+
a string specifying the reduction method.
13+
should be one of 'elementwise_mean' | 'none' | 'sum'
14+
Returns
15+
-------
16+
torch.Tensor
17+
reduced Tensor
18+
Raises
19+
------
20+
ValueError
21+
if an invalid reduction parameter was given
22+
"""
23+
if reduction == 'elementwise_mean':
24+
return torch.mean(to_reduce)
25+
if reduction == 'none':
26+
return to_reduce
27+
if reduction == 'sum':
28+
return torch.sum(to_reduce)
29+
raise ValueError('Reduction parameter unknown.')

0 commit comments

Comments
 (0)