99
1010def to_onehot (tensor : torch .Tensor ,
1111 n_classes : Optional [int ] = None ) -> torch .Tensor :
12+ """ Converts a dense label tensor to one-hot format
13+
14+ Args:
15+ tensor: dense label tensor, with shape [N, d1, d2, ...]
16+
17+ n_classes: number of classes C
18+
19+ Output:
20+ A sparse label tensor with shape [N, C, d1, d2, ...]
21+ """
1222 if n_classes is None :
1323 n_classes = int (tensor .max ().detach ().item () + 1 )
1424 dtype , device , shape = tensor .dtype , tensor .device , tensor .shape
@@ -19,11 +29,15 @@ def to_onehot(tensor: torch.Tensor,
1929
2030
2131def to_categorical (tensor : torch .Tensor , argmax_dim : int = 1 ) -> torch .Tensor :
32+ """ Converts a tensor of probabilities to a dense label tensor """
2233 return torch .argmax (tensor , dim = argmax_dim )
2334
2435
2536def get_num_classes (pred : torch .Tensor , target : torch .Tensor ,
2637 num_classes : Optional [int ]) -> int :
38+ """ Returns the number of classes for a given prediction and
39+ target tensor
40+ """
2741 if num_classes is None :
2842 if pred .ndim > target .ndim :
2943 num_classes = pred .size (1 )
@@ -36,6 +50,20 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
3650 class_index : int , argmax_dim : int = 1
3751 ) -> Tuple [torch .Tensor , torch .Tensor ,
3852 torch .Tensor , torch .Tensor ]:
53+ """ Calculates the number of true postive, false postive, true negative
54+ and false negative for a specfic class
55+
56+ Args:
57+ pred: prediction tensor
58+
59+ target: target tensor
60+
61+ class_index: class to calculate over
62+
63+ argmax_dim: if pred is a tensor of probabilities, this indicates the
64+ axis the argmax transformation will be applied over
65+
66+ """
3967 if pred .ndim == target .ndim + 1 :
4068 pred = to_categorical (pred , argmax_dim = argmax_dim )
4169
@@ -52,6 +80,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
5280 argmax_dim : int = 1
5381 ) -> Tuple [torch .Tensor , torch .Tensor ,
5482 torch .Tensor , torch .Tensor ]:
83+ """ Calls the stat_scores function iteratively for all classes, thus
84+ calculating the number of true postive, false postive, true negative
85+ and false negative for each class
86+
87+ Args:
88+ pred: prediction tensor
89+
90+ target: target tensor
91+
92+ class_index: class to calculate over
93+
94+ argmax_dim: if pred is a tensor of probabilities, this indicates the
95+ axis the argmax transformation will be applied over
96+
97+ """
5598 num_classes = get_num_classes (pred = pred , target = target ,
5699 num_classes = num_classes )
57100
0 commit comments