Skip to content

Commit 4147791

Browse files
WindQAQ authored and seanpmorgan committed
Fix overflow of CohenKappa (#675)
* fix overflow of int32
1 parent 2b070a1 commit 4147791

File tree

2 files changed

+57
-45
lines changed

2 files changed

+57
-45
lines changed

tensorflow_addons/metrics/cohens_kappa.py

Lines changed: 47 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ class CohenKappa(Metric):
3737
while calculating the Cohen's Kappa score.
3838
3939
Usage:
40+
4041
```python
4142
actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
4243
preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)
@@ -51,103 +52,102 @@ class CohenKappa(Metric):
5152
m.update_state(actuals, preds, sample_weight=weights)
5253
print('Final result: ', m.result().numpy()) # Result: 0.37209308
5354
```
55+
5456
Usage with tf.keras API:
57+
5558
```python
56-
model = keras.models.Model(inputs, outputs)
59+
model = tf.keras.models.Model(inputs, outputs)
5760
model.add_metric(tfa.metrics.CohenKappa(num_classes=5)(outputs))
5861
model.compile('sgd', loss='mse')
5962
```
60-
61-
Args:
62-
num_classes : Number of unique classes in your dataset
63-
weightage : Weighting to be considered for calculating
64-
kappa statistics. A valid value is one of
65-
[None, 'linear', 'quadratic']. Defaults to None.
66-
67-
Returns:
68-
kappa_score : float
69-
The kappa statistic, which is a number between -1 and 1. The maximum
70-
value means complete agreement; zero or lower means chance agreement.
71-
72-
Raises:
73-
ValueError: If the value passed for `weightage` is invalid
74-
i.e. not any one of [None, 'linear', 'quadratic']
7563
"""
7664

7765
def __init__(self,
7866
num_classes,
7967
name='cohen_kappa',
8068
weightage=None,
81-
dtype=tf.float32):
69+
dtype=None):
70+
"""Creates a `CohenKappa` instance.
71+
72+
Args:
73+
num_classes: Number of unique classes in your dataset.
74+
name: (Optional) String name of the metric instance.
75+
weightage: (Optional) Weighting to be considered for calculating
76+
kappa statistics. A valid value is one of
77+
[None, 'linear', 'quadratic']. Defaults to `None`.
78+
dtype: (Optional) Data type of the metric result.
79+
Defaults to `None`.
80+
81+
Raises:
82+
ValueError: If the value passed for `weightage` is invalid
83+
i.e. not any one of [None, 'linear', 'quadratic']
84+
"""
8285
super(CohenKappa, self).__init__(name=name, dtype=dtype)
8386

8487
if weightage not in (None, 'linear', 'quadratic'):
8588
raise ValueError("Unknown kappa weighting type.")
86-
else:
87-
self.weightage = weightage
8889

90+
self.weightage = weightage
8991
self.num_classes = num_classes
9092
self.conf_mtx = self.add_weight(
9193
'conf_mtx',
9294
shape=(self.num_classes, self.num_classes),
9395
initializer=tf.keras.initializers.zeros,
94-
dtype=tf.int32)
96+
dtype=tf.int64)
9597

9698
def update_state(self, y_true, y_pred, sample_weight=None):
9799
"""Accumulates the confusion matrix condition statistics.
98100
99101
Args:
100-
y_true : array, shape = [n_samples]
101-
Labels assigned by the first annotator.
102-
y_pred : array, shape = [n_samples]
103-
Labels assigned by the second annotator. The kappa statistic
104-
is symmetric, so swapping ``y_true`` and ``y_pred`` doesn't
105-
change the value.
106-
sample_weight(optional) : for weighting labels in confusion matrix
107-
Default is None. The dtype for weights should be the same
108-
as the dtype for confusion matrix. For more details,
109-
please check tf.math.confusion_matrix.
110-
102+
y_true: Labels assigned by the first annotator with shape
103+
`[num_samples,]`.
104+
y_pred: Labels assigned by the second annotator with shape
105+
`[num_samples,]`. The kappa statistic is symmetric,
106+
so swapping `y_true` and `y_pred` doesn't change the value.
107+
sample_weight (optional): for weighting labels in confusion matrix
108+
Defaults to `None`. The dtype for weights should be the same
109+
as the dtype for confusion matrix. For more details,
110+
please check `tf.math.confusion_matrix`.
111111
112112
Returns:
113113
Update op.
114114
"""
115-
y_true = tf.cast(y_true, dtype=tf.int32)
116-
y_pred = tf.cast(y_pred, dtype=tf.int32)
115+
y_true = tf.cast(y_true, dtype=tf.int64)
116+
y_pred = tf.cast(y_pred, dtype=tf.int64)
117117

118118
if y_true.shape != y_pred.shape:
119119
raise ValueError(
120-
"Number of samples in y_true and y_pred are different")
120+
"Number of samples in `y_true` and `y_pred` are different")
121121

122122
# compute the new values of the confusion matrix
123123
new_conf_mtx = tf.math.confusion_matrix(
124124
labels=y_true,
125125
predictions=y_pred,
126126
num_classes=self.num_classes,
127-
weights=sample_weight)
127+
weights=sample_weight,
128+
dtype=tf.int64)
128129

129130
# update the values in the original confusion matrix
130131
return self.conf_mtx.assign_add(new_conf_mtx)
131132

132133
def result(self):
133134
nb_ratings = tf.shape(self.conf_mtx)[0]
134-
weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int32)
135+
weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int64)
135136

136137
# 2. Create a weight matrix
137138
if self.weightage is None:
138-
diagonal = tf.zeros([nb_ratings], dtype=tf.int32)
139+
diagonal = tf.zeros([nb_ratings], dtype=tf.int64)
139140
weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal)
140-
weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
141-
142141
else:
143-
weight_mtx += tf.range(nb_ratings, dtype=tf.int32)
144-
weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
142+
weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.int64)
143+
weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)
145144

146145
if self.weightage == 'linear':
147146
weight_mtx = tf.abs(weight_mtx - tf.transpose(weight_mtx))
148147
else:
149148
weight_mtx = tf.pow((weight_mtx - tf.transpose(weight_mtx)), 2)
150-
weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
149+
150+
weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)
151151

152152
# 3. Get counts
153153
actual_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=1)
@@ -161,8 +161,8 @@ def result(self):
161161
conf_mtx = self.conf_mtx / tf.reduce_sum(self.conf_mtx)
162162
out_prod = out_prod / tf.reduce_sum(out_prod)
163163

164-
conf_mtx = tf.cast(conf_mtx, dtype=tf.float32)
165-
out_prod = tf.cast(out_prod, dtype=tf.float32)
164+
conf_mtx = tf.cast(conf_mtx, dtype=self.dtype)
165+
out_prod = tf.cast(out_prod, dtype=self.dtype)
166166

167167
# 6. Calculate Kappa score
168168
numerator = tf.reduce_sum(conf_mtx * weight_mtx)
@@ -185,4 +185,6 @@ def reset_states(self):
185185

186186
for v in self.variables:
187187
K.set_value(
188-
v, np.zeros((self.num_classes, self.num_classes), np.int32))
188+
v,
189+
np.zeros((self.num_classes, self.num_classes),
190+
v.dtype.as_numpy_dtype))

tensorflow_addons/metrics/cohens_kappa_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,16 @@ def test_kappa_with_sample_weights(self):
130130
self.check_results([kp_obj1, kp_obj2, kp_obj3],
131131
[-0.25473321, -0.38992332, -0.60695344])
132132

133+
def test_large_values(self):
134+
y_true = [1] * 10000 + [0] * 20000 + [1] * 20000
135+
y_pred = [0] * 20000 + [1] * 30000
136+
137+
obj = CohenKappa(num_classes=2)
138+
self.evaluate(tf.compat.v1.variables_initializer(obj.variables))
139+
140+
self.evaluate(obj.update_state(y_true, y_pred))
141+
self.assertAllClose(0.166666666, obj.result())
142+
133143

134144
if __name__ == '__main__':
135145
tf.test.main()

0 commit comments

Comments
 (0)