Skip to content

Commit c3ad1ca

Browse files
Nicki SkafteBorda
authored andcommitted
function descriptions
1 parent 6b6b96d commit c3ad1ca

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

pytorch_lightning/metrics/functional/classification.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@
99

1010
def to_onehot(tensor: torch.Tensor,
1111
n_classes: Optional[int] = None) -> torch.Tensor:
12+
""" Converts a dense label tensor to one-hot format
13+
14+
Args:
15+
tensor: dense label tensor, with shape [N, d1, d2, ...]
16+
17+
n_classes: number of classes C
18+
19+
Output:
20+
A sparse label tensor with shape [N, C, d1, d2, ...]
21+
"""
1222
if n_classes is None:
1323
n_classes = int(tensor.max().detach().item() + 1)
1424
dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
@@ -19,11 +29,15 @@ def to_onehot(tensor: torch.Tensor,
1929

2030

2131
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
32+
""" Converts a tensor of probabilities to a dense label tensor """
2233
return torch.argmax(tensor, dim=argmax_dim)
2334

2435

2536
def get_num_classes(pred: torch.Tensor, target: torch.Tensor,
2637
num_classes: Optional[int]) -> int:
38+
""" Returns the number of classes for a given prediction and
39+
target tensor
40+
"""
2741
if num_classes is None:
2842
if pred.ndim > target.ndim:
2943
num_classes = pred.size(1)
@@ -36,6 +50,20 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
3650
class_index: int, argmax_dim: int = 1
3751
) -> Tuple[torch.Tensor, torch.Tensor,
3852
torch.Tensor, torch.Tensor]:
53+
""" Calculates the number of true postive, false postive, true negative
54+
and false negative for a specfic class
55+
56+
Args:
57+
pred: prediction tensor
58+
59+
target: target tensor
60+
61+
class_index: class to calculate over
62+
63+
argmax_dim: if pred is a tensor of probabilities, this indicates the
64+
axis the argmax transformation will be applied over
65+
66+
"""
3967
if pred.ndim == target.ndim + 1:
4068
pred = to_categorical(pred, argmax_dim=argmax_dim)
4169

@@ -52,6 +80,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
5280
argmax_dim: int = 1
5381
) -> Tuple[torch.Tensor, torch.Tensor,
5482
torch.Tensor, torch.Tensor]:
83+
""" Calls the stat_scores function iteratively for all classes, thus
84+
calculating the number of true postive, false postive, true negative
85+
and false negative for each class
86+
87+
Args:
88+
pred: prediction tensor
89+
90+
target: target tensor
91+
92+
class_index: class to calculate over
93+
94+
argmax_dim: if pred is a tensor of probabilities, this indicates the
95+
axis the argmax transformation will be applied over
96+
97+
"""
5598
num_classes = get_num_classes(pred=pred, target=target,
5699
num_classes=num_classes)
57100

0 commit comments

Comments
 (0)