Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
793cf60
add Cohens Kappa Metric
AakashKumarNain Jun 1, 2019
b3c6e36
add tests for Cohens Kappa Metric
AakashKumarNain Jun 1, 2019
56394d1
include Cohens Kappa and tests
AakashKumarNain Jun 1, 2019
d5eb0da
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Jun 3, 2019
d669b11
code refactor and remove extra lines
AakashKumarNain Jun 3, 2019
142e61e
add separate tests for each case
AakashKumarNain Jun 3, 2019
29757da
refactor code
AakashKumarNain Jun 3, 2019
c865196
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Jun 9, 2019
1d98713
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Jun 10, 2019
ec2ed37
make the metric stateful
AakashKumarNain Jun 10, 2019
2b679c6
refactor tests
AakashKumarNain Jun 10, 2019
85f641e
add get_config and reset_states methods
AakashKumarNain Jun 11, 2019
db6bddd
refactor code and add test for sample_weight param
AakashKumarNain Jun 11, 2019
dda3336
add CohenKappa metric
AakashKumarNain Jun 11, 2019
44d89b1
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Jun 11, 2019
520a082
format code
AakashKumarNain Jun 11, 2019
cb46fe5
format code
AakashKumarNain Jun 11, 2019
9ea909e
make sure all tests pass
AakashKumarNain Jun 11, 2019
d2b87a4
fix typo in imports
AakashKumarNain Jun 11, 2019
b025d09
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Jun 11, 2019
319da19
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain Jun 11, 2019
afc35b5
code refactor
AakashKumarNain Jun 11, 2019
8cadec4
refactor code
AakashKumarNain Jun 11, 2019
5711b18
update README
AakashKumarNain Jun 11, 2019
593f7fd
fix typo
AakashKumarNain Jun 11, 2019
6bf67e2
remove math import
AakashKumarNain Jun 12, 2019
da65a6c
fix imports
AakashKumarNain Jun 12, 2019
0c31db8
fix initializer
AakashKumarNain Jun 12, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_addons/metrics/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,23 @@ py_library(
name = "metrics",
srcs = [
"__init__.py",
"cohens_kappa.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow_addons/utils",
],
)

# Bazel test target for the Cohen's Kappa metric. It runs
# cohens_kappa_test.py and depends on the ":metrics" target.
py_test(
name = "cohens_kappa_test",
size = "small",
srcs = [
"cohens_kappa_test.py",
],
main = "cohens_kappa_test.py",
srcs_version = "PY2AND3",
deps = [
":metrics",
],
)
6 changes: 3 additions & 3 deletions tensorflow_addons/metrics/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
## Maintainers
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| | | |
| cohens_kappa| Aakash Nain | [email protected]|

## Contents
| Submodule | Activation | Reference |
| Submodule | Metric | Reference |
|:----------------------- |:-------------------|:---------------|
| | | |
| cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)|


## Contribution Guidelines
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_addons/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow_addons.metrics.cohens_kappa import CohenKappa
183 changes: 183 additions & 0 deletions tensorflow_addons/metrics/cohens_kappa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements Cohen's Kappa."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
class CohenKappa(Metric):
    """Computes Cohen's Kappa score between two raters.

    The score lies in the range [-1, 1]. A score of -1 represents
    complete disagreement between two raters whereas a score of 1
    represents complete agreement between the two raters.
    A score of 0 means agreement by chance.

    Note: As of now, this implementation considers all labels
    while calculating the Cohen's Kappa score.

    Usage:
    ```python
    actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32)
    preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32)

    # The weighting scheme is a constructor argument; the third
    # positional argument of `update_state` is `sample_weight`.
    m = tfa.metrics.CohenKappa(num_classes=5, weightage='quadratic')
    m.update_state(actuals, preds)
    print('Final result: ', m.result().numpy())  # Result: 0.68932
    ```
    Usage with tf.keras API:
    ```python
    # `update_state` casts both labels and predictions to int32, so the
    # model outputs must already be class indices.
    model = keras.models.Model(inputs, outputs)
    model.compile('sgd', loss='mse',
                  metrics=[tfa.metrics.CohenKappa(num_classes=5)])
    ```

    Args:
      num_classes : Number of unique classes in your dataset.
      name : Name of the metric instance. Defaults to 'cohen_kappa'.
      weightage : Weighting to be considered for calculating
        kappa statistics. A valid value is one of
        [None, 'linear', 'quadratic']. Defaults to None.
      dtype : dtype forwarded to the base `Metric` constructor.

    Returns:
      kappa_score : float
        The kappa statistic, which is a number between -1 and 1. The maximum
        value means complete agreement; zero or lower means chance agreement.

    Raises:
      ValueError: If the value passed for `weightage` is invalid
      i.e. not any one of [None, 'linear', 'quadratic']
    """

    def __init__(self,
                 num_classes,
                 name='cohen_kappa',
                 weightage=None,
                 dtype=tf.float32):
        super(CohenKappa, self).__init__(name=name, dtype=dtype)

        if weightage not in (None, 'linear', 'quadratic'):
            raise ValueError("Unknown kappa weighting type.")

        self.weightage = weightage
        self.num_classes = num_classes

        # Running confusion matrix; rows are indexed by `y_true` and
        # columns by `y_pred`. It is the only state of this metric.
        self.conf_mtx = self.add_weight(
            'conf_mtx',
            shape=(self.num_classes, self.num_classes),
            initializer=tf.keras.initializers.zeros,
            dtype=tf.int32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulates the confusion matrix condition statistics.

        Args:
          y_true : array, shape = [n_samples]
            Labels assigned by the first annotator.
          y_pred : array, shape = [n_samples]
            Labels assigned by the second annotator. The kappa statistic
            is symmetric, so swapping ``y_true`` and ``y_pred`` doesn't
            change the value.
          sample_weight(optional) : for weighting labels in confusion matrix
            Default is None. The dtype for weights should be the same
            as the dtype for confusion matrix (int32). For more details,
            please check tf.math.confusion_matrix.

        Returns:
          Update op.

        Raises:
          ValueError: If `y_true` and `y_pred` have different shapes.
        """
        y_true = tf.cast(y_true, dtype=tf.int32)
        y_pred = tf.cast(y_pred, dtype=tf.int32)

        if y_true.shape != y_pred.shape:
            raise ValueError(
                "Number of samples in y_true and y_pred are different")

        # Confusion matrix of the current batch only.
        new_conf_mtx = tf.math.confusion_matrix(
            labels=y_true,
            predictions=y_pred,
            num_classes=self.num_classes,
            weights=sample_weight)

        # Fold the batch statistics into the running confusion matrix.
        return self.conf_mtx.assign_add(new_conf_mtx)

    def result(self):
        """Computes the kappa score from the accumulated confusion matrix.

        Returns:
          A scalar float32 tensor equal to
          `1 - sum(w * observed) / sum(w * expected)`, where `observed` is
          the normalized confusion matrix, `expected` is the normalized
          outer product of its marginals and `w` is the disagreement
          weight matrix selected by `weightage`. The value is NaN when the
          denominator is zero, e.g. before any sample has been accumulated.
        """
        nb_ratings = tf.shape(self.conf_mtx)[0]
        weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int32)

        # 1. Create the disagreement weight matrix.
        if self.weightage is None:
            # Unweighted kappa: every disagreement costs 1, agreement 0.
            diagonal = tf.zeros([nb_ratings], dtype=tf.int32)
            weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal)
            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
        else:
            # Broadcasting makes entry (i, j) equal to 1 + j, so subtracting
            # the transpose yields the class distance j - i.
            weight_mtx += tf.range(nb_ratings, dtype=tf.int32)
            weight_mtx = tf.cast(weight_mtx, dtype=tf.float32)
            distance = weight_mtx - tf.transpose(weight_mtx)

            if self.weightage == 'linear':
                weight_mtx = tf.abs(distance)
            else:
                weight_mtx = tf.pow(distance, 2)

        # 2. Get the marginal counts. The counts are promoted to float64
        # before any arithmetic: the outer product below and its sum (which
        # equals the squared sample count) overflow int32 once roughly 46k
        # samples have been accumulated.
        counts = tf.cast(self.conf_mtx, dtype=tf.float64)
        actual_ratings_hist = tf.reduce_sum(counts, axis=1)
        pred_ratings_hist = tf.reduce_sum(counts, axis=0)

        # 3. Get the outer product of the marginals (expected agreement).
        out_prod = pred_ratings_hist[..., None] * \
            actual_ratings_hist[None, ...]

        # 4. Normalize the confusion matrix and the outer product.
        conf_mtx = counts / tf.reduce_sum(counts)
        out_prod = out_prod / tf.reduce_sum(out_prod)

        conf_mtx = tf.cast(conf_mtx, dtype=tf.float32)
        out_prod = tf.cast(out_prod, dtype=tf.float32)

        # 5. Calculate the kappa score.
        numerator = tf.reduce_sum(conf_mtx * weight_mtx)
        denominator = tf.reduce_sum(out_prod * weight_mtx)
        return 1 - (numerator / denominator)

    def get_config(self):
        """Returns the serializable config of the metric."""

        config = {
            "num_classes": self.num_classes,
            "weightage": self.weightage,
        }
        base_config = super(CohenKappa, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def reset_states(self):
        """Resets all of the metric state variables."""

        # Every state variable is overwritten with an int32 zero matrix of
        # shape (num_classes, num_classes), which matches `conf_mtx`.
        for v in self.variables:
            K.set_value(
                v, np.zeros((self.num_classes, self.num_classes), np.int32))
135 changes: 135 additions & 0 deletions tensorflow_addons/metrics/cohens_kappa_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Cohen's Kappa Metric."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.metrics import CohenKappa
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class CohenKappaTest(tf.test.TestCase):
    """Tests for the `CohenKappa` metric with each weighting scheme."""

    def test_config(self):
        """Checks the constructor arguments and the config round trip."""
        kp_obj = CohenKappa(name='cohen_kappa', num_classes=5)
        self.assertEqual(kp_obj.name, 'cohen_kappa')
        self.assertEqual(kp_obj.dtype, tf.float32)
        self.assertEqual(kp_obj.num_classes, 5)

        # Check save and restore config
        kb_obj2 = CohenKappa.from_config(kp_obj.get_config())
        self.assertEqual(kb_obj2.name, 'cohen_kappa')
        self.assertEqual(kb_obj2.dtype, tf.float32)
        # Assert on the restored object, not on the original one, so that
        # the round trip of `num_classes` is actually verified.
        self.assertEqual(kb_obj2.num_classes, 5)

    def initialize_vars(self):
        """Builds one metric per weighting scheme and initializes its state.

        Returns:
          A tuple `(unweighted, linear, quadratic)` of `CohenKappa` objects,
          each created with `num_classes=5`.
        """
        kp_obj1 = CohenKappa(num_classes=5)
        kp_obj2 = CohenKappa(num_classes=5, weightage='linear')
        kp_obj3 = CohenKappa(num_classes=5, weightage='quadratic')

        for kp_obj in (kp_obj1, kp_obj2, kp_obj3):
            self.evaluate(
                tf.compat.v1.variables_initializer(kp_obj.variables))
        return kp_obj1, kp_obj2, kp_obj3

    def update_obj_states(self, obj1, obj2, obj3, actuals, preds, weights):
        """Feeds the same batch to all three metrics and runs the updates."""
        # Create every update op first, then evaluate them in order.
        update_ops = [
            obj.update_state(actuals, preds, sample_weight=weights)
            for obj in (obj1, obj2, obj3)
        ]
        for update_op in update_ops:
            self.evaluate(update_op)

    def check_results(self, objs, values):
        """Asserts that each metric result matches its expected value."""
        obj1, obj2, obj3 = objs
        val1, val2, val3 = values

        for expected, obj in ((val1, obj1), (val2, obj2), (val3, obj3)):
            self.assertAllClose(
                expected, self.evaluate(obj.result()), atol=1e-5)

    def test_kappa_random_score(self):
        """Partial agreement between the two raters."""
        actuals = [4, 4, 3, 4, 2, 4, 1, 1]
        preds = [4, 4, 3, 4, 4, 2, 1, 1]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds,
                               None)

        # Check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3],
                           [0.61904761, 0.62790697, 0.68932038])

    def test_kappa_perfect_score(self):
        """Identical ratings must give a score of 1 for every weighting."""
        actuals = [4, 4, 3, 3, 2, 2, 1, 1]
        preds = [4, 4, 3, 3, 2, 2, 1, 1]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds,
                               None)

        # Check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3], [1.0, 1.0, 1.0])

    def test_kappa_worse_than_random(self):
        """Systematic disagreement yields negative scores."""
        actuals = [4, 4, 3, 3, 2, 2, 1, 1]
        preds = [1, 2, 4, 1, 3, 3, 4, 4]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds,
                               None)

        # check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3],
                           [-0.3333333, -0.52380952, -0.72727272])

    def test_kappa_with_sample_weights(self):
        """Integer sample weights are applied to the confusion matrix."""
        actuals = [4, 4, 3, 3, 2, 2, 1, 1]
        preds = [1, 2, 4, 1, 3, 3, 4, 4]
        weights = [1, 1, 2, 5, 10, 2, 3, 3]
        actuals = tf.constant(actuals, dtype=tf.int32)
        preds = tf.constant(preds, dtype=tf.int32)
        weights = tf.constant(weights, dtype=tf.int32)

        # Initialize
        kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars()

        # Update
        self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds,
                               weights)

        # check results
        self.check_results([kp_obj1, kp_obj2, kp_obj3],
                           [-0.25473321, -0.38992332, -0.60695344])


# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    tf.test.main()