@@ -37,6 +37,7 @@ class CohenKappa(Metric):
     while calculating the Cohen's Kappa score.
 
     Usage:
+
     ```python
     actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
     preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)
@@ -51,103 +52,102 @@ class CohenKappa(Metric):
     m.update_state(actuals, preds, sample_weight=weights)
     print('Final result: ', m.result().numpy()) # Result: 0.37209308
     ```
+
     Usage with tf.keras API:
+
     ```python
-    model = keras.models.Model(inputs, outputs)
+    model = tf.keras.models.Model(inputs, outputs)
     model.add_metric(tfa.metrics.CohenKappa(num_classes=5)(outputs))
     model.compile('sgd', loss='mse')
     ```
-
-    Args:
-      num_classes : Number of unique classes in your dataset
-      weightage : Weighting to be considered for calculating
-        kappa statistics. A valid value is one of
-        [None, 'linear', 'quadratic']. Defaults to None.
-
-    Returns:
-      kappa_score : float
-        The kappa statistic, which is a number between -1 and 1. The maximum
-        value means complete agreement; zero or lower means chance agreement.
-
-    Raises:
-      ValueError: If the value passed for `weightage` is invalid
-        i.e. not any one of [None, 'linear', 'quadratic']
     """
 
     def __init__(self,
                  num_classes,
                  name='cohen_kappa',
                  weightage=None,
-                 dtype=tf.float32):
+                 dtype=None):
+        """Creates a `CohenKappa` instance.
+
+        Args:
+          num_classes: Number of unique classes in your dataset.
+          name: (Optional) String name of the metric instance.
+          weightage: (Optional) Weighting to be considered for calculating
+            kappa statistics. A valid value is one of
+            [None, 'linear', 'quadratic']. Defaults to `None`.
+          dtype: (Optional) Data type of the metric result.
+            Defaults to `None`.
+
+        Raises:
+          ValueError: If the value passed for `weightage` is invalid
+            i.e. not any one of [None, 'linear', 'quadratic']
+        """
         super(CohenKappa, self).__init__(name=name, dtype=dtype)
 
         if weightage not in (None, 'linear', 'quadratic'):
             raise ValueError("Unknown kappa weighting type.")
-        else:
-            self.weightage = weightage
 
+        self.weightage = weightage
         self.num_classes = num_classes
         self.conf_mtx = self.add_weight(
             'conf_mtx',
             shape=(self.num_classes, self.num_classes),
             initializer=tf.keras.initializers.zeros,
-            dtype=tf.int32)
+            dtype=tf.int64)
 
     def update_state(self, y_true, y_pred, sample_weight=None):
         """Accumulates the confusion matrix condition statistics.
 
         Args:
-          y_true : array, shape = [n_samples]
-            Labels assigned by the first annotator.
-          y_pred : array, shape = [n_samples]
-            Labels assigned by the second annotator. The kappa statistic
-            is symmetric, so swapping ``y_true`` and ``y_pred`` doesn't
-            change the value.
-          sample_weight(optional) : for weighting labels in confusion matrix
-            Default is None. The dtype for weights should be the same
-            as the dtype for confusion matrix. For more details,
-            please check tf.math.confusion_matrix.
-
+          y_true: Labels assigned by the first annotator with shape
+            `[num_samples,]`.
+          y_pred: Labels assigned by the second annotator with shape
+            `[num_samples,]`. The kappa statistic is symmetric,
+            so swapping `y_true` and `y_pred` doesn't change the value.
+          sample_weight (optional): for weighting labels in the confusion matrix.
+            Defaults to `None`. The dtype for weights should be the same
+            as the dtype for the confusion matrix. For more details,
+            please check `tf.math.confusion_matrix`.
 
         Returns:
           Update op.
         """
-        y_true = tf.cast(y_true, dtype=tf.int32)
-        y_pred = tf.cast(y_pred, dtype=tf.int32)
+        y_true = tf.cast(y_true, dtype=tf.int64)
+        y_pred = tf.cast(y_pred, dtype=tf.int64)
 
         if y_true.shape != y_pred.shape:
             raise ValueError(
-                "Number of samples in y_true and y_pred are different")
+                "Number of samples in `y_true` and `y_pred` are different")
 
         # compute the new values of the confusion matrix
         new_conf_mtx = tf.math.confusion_matrix(
             labels=y_true,
             predictions=y_pred,
             num_classes=self.num_classes,
-            weights=sample_weight)
+            weights=sample_weight,
+            dtype=tf.int64)
 
         # update the values in the original confusion matrix
         return self.conf_mtx.assign_add(new_conf_mtx)
 
     def result(self):
         nb_ratings = tf.shape(self.conf_mtx)[0]
-        weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int32)
+        weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int64)
 
         # 2. Create a weight matrix
         if self.weightage is None:
-            diagonal = tf.zeros([nb_ratings], dtype=tf.int32)
+            diagonal = tf.zeros([nb_ratings], dtype=tf.int64)
             weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal)
-            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
-
         else:
-            weight_mtx += tf.range(nb_ratings, dtype=tf.int32)
-            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
+            weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.int64)
+            weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)
 
             if self.weightage == 'linear':
                 weight_mtx = tf.abs(weight_mtx - tf.transpose(weight_mtx))
             else:
                 weight_mtx = tf.pow((weight_mtx - tf.transpose(weight_mtx)), 2)
-            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
+
+        weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)
 
         # 3. Get counts
         actual_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=1)
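For readers tracing the `result()` change above, the weight matrix built in step 2 reduces to a simple function of the class indices: zeros on the diagonal and ones elsewhere for unweighted kappa, `|i - j|` for `'linear'`, and `(i - j)**2` for `'quadratic'`. A minimal NumPy sketch of that construction, using a hypothetical helper name `kappa_weight_matrix` that is not part of the library:

```python
import numpy as np


def kappa_weight_matrix(num_classes, weightage=None):
    """Illustrative NumPy counterpart of step 2 in `result()` above."""
    idx = np.arange(num_classes)
    # |i - j| for every pair of class indices (i, j).
    diff = np.abs(idx[:, None] - idx[None, :])
    if weightage is None:
        # Unweighted kappa: penalize every disagreement equally.
        return (diff != 0).astype(np.float32)
    elif weightage == 'linear':
        return diff.astype(np.float32)
    elif weightage == 'quadratic':
        return (diff ** 2).astype(np.float32)
    raise ValueError("Unknown kappa weighting type.")


print(kappa_weight_matrix(4, 'linear'))
# [[0. 1. 2. 3.]
#  [1. 0. 1. 2.]
#  [2. 1. 0. 1.]
#  [3. 2. 1. 0.]]
```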
@@ -161,8 +161,8 @@ def result(self):
         conf_mtx = self.conf_mtx / tf.reduce_sum(self.conf_mtx)
         out_prod = out_prod / tf.reduce_sum(out_prod)
 
-        conf_mtx = tf.cast(conf_mtx, dtype=tf.float32)
-        out_prod = tf.cast(out_prod, dtype=tf.float32)
+        conf_mtx = tf.cast(conf_mtx, dtype=self.dtype)
+        out_prod = tf.cast(out_prod, dtype=self.dtype)
 
         # 6. Calculate Kappa score
         numerator = tf.reduce_sum(conf_mtx * weight_mtx)
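The two reductions above implement the weighted kappa formula `kappa = 1 - sum(w * O) / sum(w * E)`, where `O` is the normalized observed confusion matrix and `E` is the outer product of its row and column marginals (the agreement expected by chance). A minimal NumPy sketch of steps 3-6, using a hypothetical helper name `cohen_kappa_from_confusion`:

```python
import numpy as np


def cohen_kappa_from_confusion(conf_mtx, weight_mtx):
    """Illustrative NumPy version of steps 3-6 in `result()` above."""
    conf_mtx = np.asarray(conf_mtx, dtype=np.float64)
    # 3. Row/column totals: how often each rater used each class.
    actual_hist = conf_mtx.sum(axis=1)
    pred_hist = conf_mtx.sum(axis=0)
    # 4. Expected agreement under independence (outer product of marginals).
    out_prod = np.outer(actual_hist, pred_hist)
    # 5. Normalize both matrices so each sums to 1.
    conf_mtx = conf_mtx / conf_mtx.sum()
    out_prod = out_prod / out_prod.sum()
    # 6. Weighted disagreement, observed vs. expected.
    numerator = (conf_mtx * weight_mtx).sum()
    denominator = (out_prod * weight_mtx).sum()
    return 1.0 - numerator / denominator


# Confusion matrix for the unweighted example in the class docstring
# (actuals vs. preds, classes 0..4), with 0/1 disagreement weights:
conf = np.zeros((5, 5))
conf[4, 4], conf[3, 3], conf[1, 1] = 3, 1, 2
conf[2, 4] = conf[4, 2] = 1
idx = np.arange(5)
weights = (idx[:, None] != idx[None, :]).astype(np.float64)
print(cohen_kappa_from_confusion(conf, weights))  # ~0.6190476, as in the docstring
```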
@@ -185,4 +185,6 @@ def reset_states(self):
 
         for v in self.variables:
             K.set_value(
-                v, np.zeros((self.num_classes, self.num_classes), np.int32))
+                v,
+                np.zeros((self.num_classes, self.num_classes),
+                         v.dtype.as_numpy_dtype))
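The confusion-matrix variable is accumulated in `tf.int64`, while the kappa score itself is cast to the metric's `dtype`, which defaults to `None` (the Keras default) rather than a hard-coded `tf.float32`. A short usage sketch under that assumption, requesting an explicit result dtype (the particular `tf.float64` and `'quadratic'` choices are only illustrative):

```python
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)

# An explicit dtype controls the precision of the returned score,
# while the internal confusion matrix stays int64.
m = tfa.metrics.CohenKappa(num_classes=5, weightage='quadratic',
                           dtype=tf.float64)
m.update_state(actuals, preds)
print(m.result().numpy())  # float64 scalar kappa score

m.reset_states()  # zeroes the int64 confusion-matrix variable
```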