@@ -87,7 +87,9 @@ def __init__(
8787 process_group : Optional [Any ] = None ,
8888 ):
8989 super ().__init__ (
90- compute_on_step = compute_on_step , dist_sync_on_step = dist_sync_on_step , process_group = process_group ,
90+ compute_on_step = compute_on_step ,
91+ dist_sync_on_step = dist_sync_on_step ,
92+ process_group = process_group ,
9193 )
9294
9395 self .num_classes = num_classes
@@ -98,8 +100,10 @@ def __init__(
98100
99101 allowed_average = ("micro" , "macro" , "weighted" , None )
100102 if self .average not in allowed_average :
101- raise ValueError ('Argument `average` expected to be one of the following:'
102- f' { allowed_average } but got { self .average } ' )
103+ raise ValueError (
104+ 'Argument `average` expected to be one of the following:'
105+ f' { allowed_average } but got { self .average } '
106+ )
103107
104108 self .add_state ("true_positives" , default = torch .zeros (num_classes ), dist_reduce_fx = "sum" )
105109 self .add_state ("predicted_positives" , default = torch .zeros (num_classes ), dist_reduce_fx = "sum" )
@@ -125,8 +129,9 @@ def compute(self) -> torch.Tensor:
125129 """
126130 Computes fbeta over state.
127131 """
128- return _fbeta_compute (self .true_positives , self .predicted_positives ,
129- self .actual_positives , self .beta , self .average )
132+ return _fbeta_compute (
133+ self .true_positives , self .predicted_positives , self .actual_positives , self .beta , self .average
134+ )
130135
131136
132137class F1 (FBeta ):
0 commit comments