Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 49 additions & 47 deletions tensorflow_addons/metrics/cohens_kappa.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,117 +38,117 @@ class CohenKappa(Metric):
while calculating the Cohen's Kappa score.

Usage:

```python
actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)
weights = np.array([1, 1, 2, 5, 10, 2, 3, 3], dtype=np.int32)

m = tfa.metrics.CohenKappa(num_classes=5)
m.update_state(actuals, preds)
print('Final result: ', m.result().numpy()) # Result: 0.61904764

# To use this with weights, sample_weight argument can be used.
m = tfa.metrics.CohenKappa(num_classes=5)
m.update_state(actuals, preds, sample_weight=weights)
print('Final result: ', m.result().numpy()) # Result: 0.37209308
```

Usage with tf.keras API:

```python
model = keras.models.Model(inputs, outputs)
model = tf.keras.models.Model(inputs, outputs)
model.add_metric(tfa.metrics.CohenKappa(num_classes=5)(outputs))
model.compile('sgd', loss='mse')
```

Args:
num_classes : Number of unique classes in your dataset
weightage : Weighting to be considered for calculating
kappa statistics. A valid value is one of
[None, 'linear', 'quadratic']. Defaults to None.

Returns:
kappa_score : float
The kappa statistic, which is a number between -1 and 1. The maximum
value means complete agreement; zero or lower means chance agreement.

Raises:
ValueError: If the value passed for `weightage` is invalid
i.e. not any one of [None, 'linear', 'quadratic']
"""

def __init__(self,
             num_classes,
             name='cohen_kappa',
             weightage=None,
             dtype=None):
    """Creates a `CohenKappa` instance.

    Args:
      num_classes: Number of unique classes in your dataset.
      name: (Optional) String name of the metric instance.
      weightage: (Optional) Weighting to be considered for calculating
        kappa statistics. A valid value is one of
        [None, 'linear', 'quadratic']. Defaults to `None`.
      dtype: (Optional) Data type of the metric result.
        Defaults to `None`.

    Raises:
      ValueError: If the value passed for `weightage` is invalid
        i.e. not any one of [None, 'linear', 'quadratic']
    """
    super(CohenKappa, self).__init__(name=name, dtype=dtype)

    # Reject unsupported weighting schemes before any state is created.
    if weightage not in (None, 'linear', 'quadratic'):
        raise ValueError("Unknown kappa weighting type.")

    self.weightage = weightage
    self.num_classes = num_classes

    # The confusion matrix accumulates raw counts. It is stored as int64
    # because `result` sums over it (and over products of its row/column
    # totals); with int32 those sums overflow once the number of samples
    # reaches the tens of thousands.
    self.conf_mtx = self.add_weight(
        'conf_mtx',
        shape=(self.num_classes, self.num_classes),
        initializer=tf.keras.initializers.zeros,
        dtype=tf.int64)

def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates the confusion matrix condition statistics.

    Args:
      y_true: Labels assigned by the first annotator with shape
        `[num_samples,]`.
      y_pred: Labels assigned by the second annotator with shape
        `[num_samples,]`. The kappa statistic is symmetric,
        so swapping `y_true` and `y_pred` doesn't change the value.
      sample_weight (optional): for weighting labels in confusion matrix
        Defaults to `None`. The dtype for weights should be the same
        as the dtype for confusion matrix. For more details,
        please check `tf.math.confusion_matrix`.

    Returns:
      Update op.

    Raises:
      ValueError: If `y_true` and `y_pred` do not have the same shape.
    """
    # Labels are cast to int64 so they match the dtype requested for the
    # confusion matrix below.
    y_true = tf.cast(y_true, dtype=tf.int64)
    y_pred = tf.cast(y_pred, dtype=tf.int64)

    if y_true.shape != y_pred.shape:
        raise ValueError(
            "Number of samples in `y_true` and `y_pred` are different")

    # Compute the confusion matrix for this batch. It is requested as
    # int64 so it can be added to the int64 accumulator without a cast
    # and so large counts do not overflow.
    new_conf_mtx = tf.math.confusion_matrix(
        labels=y_true,
        predictions=y_pred,
        num_classes=self.num_classes,
        weights=sample_weight,
        dtype=tf.int64)

    # Fold the batch counts into the running confusion matrix.
    return self.conf_mtx.assign_add(new_conf_mtx)

def result(self):
nb_ratings = tf.shape(self.conf_mtx)[0]
weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int32)
weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int64)

# 2. Create a weight matrix
if self.weightage is None:
diagonal = tf.zeros([nb_ratings], dtype=tf.int32)
diagonal = tf.zeros([nb_ratings], dtype=tf.int64)
weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal)
weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)

else:
weight_mtx += tf.range(nb_ratings, dtype=tf.int32)
weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
weight_mtx += tf.cast(tf.range(nb_ratings), dtype=tf.int64)
weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)

if self.weightage == 'linear':
weight_mtx = tf.abs(weight_mtx - tf.transpose(weight_mtx))
else:
weight_mtx = tf.pow((weight_mtx - tf.transpose(weight_mtx)), 2)
weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)

weight_mtx = tf.cast(weight_mtx, dtype=self.dtype)

# 3. Get counts
actual_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=1)
Expand All @@ -162,8 +162,8 @@ def result(self):
conf_mtx = self.conf_mtx / tf.reduce_sum(self.conf_mtx)
out_prod = out_prod / tf.reduce_sum(out_prod)

conf_mtx = tf.cast(conf_mtx, dtype=tf.float32)
out_prod = tf.cast(out_prod, dtype=tf.float32)
conf_mtx = tf.cast(conf_mtx, dtype=self.dtype)
out_prod = tf.cast(out_prod, dtype=self.dtype)

# 6. Calculate Kappa score
numerator = tf.reduce_sum(conf_mtx * weight_mtx)
Expand All @@ -186,4 +186,6 @@ def reset_states(self):

for v in self.variables:
K.set_value(
v, np.zeros((self.num_classes, self.num_classes), np.int32))
v,
np.zeros((self.num_classes, self.num_classes),
v.dtype.as_numpy_dtype))
10 changes: 10 additions & 0 deletions tensorflow_addons/metrics/cohens_kappa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ def test_kappa_with_sample_weights(self):
self.check_results([kp_obj1, kp_obj2, kp_obj3],
[-0.25473321, -0.38992332, -0.60695344])

def test_large_values(self):
    """Checks the kappa score on 50,000 samples of a two-class problem."""
    labels = [1] * 10000 + [0] * 20000 + [1] * 20000
    predictions = [0] * 20000 + [1] * 30000

    kappa = CohenKappa(num_classes=2)

    # The metric's variables must be initialized before they are updated.
    init_op = tf.compat.v1.variables_initializer(kappa.variables)
    self.evaluate(init_op)

    update_op = kappa.update_state(labels, predictions)
    self.evaluate(update_op)

    self.assertAllClose(0.166666666, kappa.result())


# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
tf.test.main()