-
Notifications
You must be signed in to change notification settings - Fork 617
Add kappa #267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add kappa #267
Changes from all commits
Commits
Show all changes
28 commits
Select commit
Hold shift + click to select a range
793cf60
add Cohens Kappa Metric
AakashKumarNain b3c6e36
add tests for Cohens Kappa Metric
AakashKumarNain 56394d1
include Cohens Kappa and tests
AakashKumarNain d5eb0da
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain d669b11
code refactor and remove extra lines
AakashKumarNain 142e61e
add separate tests for each case
AakashKumarNain 29757da
refactor code
AakashKumarNain c865196
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain 1d98713
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain ec2ed37
make the metric stateful
AakashKumarNain 2b679c6
refactor tests
AakashKumarNain 85f641e
add get_config and reset_states methods
AakashKumarNain db6bddd
refactor code and add test for sample_weight param
AakashKumarNain dda3336
add CohenKappa metric
AakashKumarNain 44d89b1
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain 520a082
format code
AakashKumarNain cb46fe5
format code
AakashKumarNain 9ea909e
make sure all tests pass
AakashKumarNain d2b87a4
fix typo in imports
AakashKumarNain b025d09
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain 319da19
Merge branch 'master' of https://github.com/tensorflow/addons into ad…
AakashKumarNain afc35b5
code refactor
AakashKumarNain 8cadec4
refactor code
AakashKumarNain 5711b18
update README
AakashKumarNain 593f7fd
fix typo
AakashKumarNain 6bf67e2
remove math import
AakashKumarNain da65a6c
fix imports
AakashKumarNain 0c31db8
fix initializer
AakashKumarNain File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,12 +3,12 @@ | |
| ## Maintainers | ||
| | Submodule | Maintainers | Contact Info | | ||
| |:---------- |:------------- |:--------------| | ||
| | | | | | ||
| | cohens_kappa| Aakash Nain | [email protected]| | ||
|
|
||
| ## Contents | ||
| | Submodule | Activation | Reference | | ||
| | Submodule | Metric | Reference | | ||
| |:----------------------- |:-------------------|:---------------| | ||
| | | | | | ||
| | cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)| | ||
|
|
||
|
|
||
| ## Contribution Guidelines | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """Implements Cohen's Kappa.""" | ||
|
|
||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import tensorflow as tf | ||
| import numpy as np | ||
| import tensorflow.keras.backend as K | ||
| from tensorflow.keras.metrics import Metric | ||
| from tensorflow_addons.utils import keras_utils | ||
|
|
||
|
|
||
| @keras_utils.register_keras_custom_object | ||
| class CohenKappa(Metric): | ||
| """Computes Kappa score between two raters. | ||
|
|
||
| The score lies in the range [-1, 1]. A score of -1 represents | ||
| complete disagreement between two raters whereas a score of 1 | ||
| represents complete agreement between the two raters. | ||
| A score of 0 means agreement by chance. | ||
|
|
||
| Note: As of now, this implementation considers all labels | ||
| while calculating the Cohen's Kappa score. | ||
|
|
||
| Usage: | ||
| ```python | ||
| actuals = np.array([4, 4, 3, 4, 2, 4, 1, 1], dtype=np.int32) | ||
| preds = np.array([4, 4, 3, 4, 4, 2, 1, 1], dtype=np.int32) | ||
|
|
||
| m = tf.keras.metrics.CohenKappa(num_classes=5) | ||
| m.update_state(actuals, preds, "quadratic") | ||
| print('Final result: ', m.result().numpy()) # Result: 0.68932 | ||
| ``` | ||
| Usage with tf.keras API: | ||
| ```python | ||
| model = keras.models.Model(inputs, outputs) | ||
| model.add_metric(tf.keras.metrics.CohenKappa(num_classes=5)(outputs)) | ||
| model.compile('sgd', loss='mse') | ||
| ``` | ||
|
|
||
| Args: | ||
| num_classes : Number of unique classes in your dataset | ||
| weightage : Weighting to be considered for calculating | ||
| kappa statistics. A valid value is one of | ||
| [None, 'linear', 'quadratic']. Defaults to None. | ||
|
|
||
| Returns: | ||
| kappa_score : float | ||
| The kappa statistic, which is a number between -1 and 1. The maximum | ||
| value means complete agreement; zero or lower means chance agreement. | ||
|
|
||
| Raises: | ||
| ValueError: If the value passed for `weightage` is invalid | ||
| i.e. not any one of [None, 'linear', 'quadratic'] | ||
| """ | ||
|
|
||
| def __init__(self, | ||
| num_classes, | ||
| name='cohen_kappa', | ||
| weightage=None, | ||
| dtype=tf.float32): | ||
| super(CohenKappa, self).__init__(name=name, dtype=dtype) | ||
|
|
||
| if weightage not in (None, 'linear', 'quadratic'): | ||
| raise ValueError("Unknown kappa weighting type.") | ||
| else: | ||
| self.weightage = weightage | ||
|
|
||
| self.num_classes = num_classes | ||
| self.conf_mtx = self.add_weight( | ||
| 'conf_mtx', | ||
| shape=(self.num_classes, self.num_classes), | ||
| initializer=tf.keras.initializers.zeros, | ||
| dtype=tf.int32) | ||
|
|
||
| def update_state(self, y_true, y_pred, sample_weight=None): | ||
| """Accumulates the confusion matrix condition statistics. | ||
|
|
||
| Args: | ||
| y_true : array, shape = [n_samples] | ||
| Labels assigned by the first annotator. | ||
| y_pred : array, shape = [n_samples] | ||
| Labels assigned by the second annotator. The kappa statistic | ||
| is symmetric, so swapping ``y_true`` and ``y_pred`` doesn't | ||
| change the value. | ||
| sample_weight(optional) : for weighting labels in confusion matrix | ||
| Default is None. The dtype for weights should be the same | ||
| as the dtype for confusion matrix. For more details, | ||
| please check tf.math.confusion_matrix. | ||
|
|
||
|
|
||
| Returns: | ||
| Update op. | ||
| """ | ||
| y_true = tf.cast(y_true, dtype=tf.int32) | ||
| y_pred = tf.cast(y_pred, dtype=tf.int32) | ||
|
|
||
| if y_true.shape != y_pred.shape: | ||
| raise ValueError( | ||
| "Number of samples in y_true and y_pred are different") | ||
|
|
||
| # compute the new values of the confusion matrix | ||
| new_conf_mtx = tf.math.confusion_matrix( | ||
| labels=y_true, | ||
| predictions=y_pred, | ||
| num_classes=self.num_classes, | ||
| weights=sample_weight) | ||
|
|
||
| # update the values in the original confusion matrix | ||
| return self.conf_mtx.assign_add(new_conf_mtx) | ||
|
|
||
| def result(self): | ||
| nb_ratings = tf.shape(self.conf_mtx)[0] | ||
| weight_mtx = tf.ones([nb_ratings, nb_ratings], dtype=tf.int32) | ||
|
|
||
| # 2. Create a weight matrix | ||
| if self.weightage is None: | ||
| diagonal = tf.zeros([nb_ratings], dtype=tf.int32) | ||
| weight_mtx = tf.linalg.set_diag(weight_mtx, diagonal=diagonal) | ||
| weight_mtx = tf.cast(weight_mtx, dtype=tf.float32) | ||
|
|
||
| else: | ||
| weight_mtx += tf.range(nb_ratings, dtype=tf.int32) | ||
| weight_mtx = tf.cast(weight_mtx, dtype=tf.float32) | ||
|
|
||
| if self.weightage == 'linear': | ||
| weight_mtx = tf.abs(weight_mtx - tf.transpose(weight_mtx)) | ||
| else: | ||
| weight_mtx = tf.pow((weight_mtx - tf.transpose(weight_mtx)), 2) | ||
| weight_mtx = tf.cast(weight_mtx, dtype=tf.float32) | ||
|
|
||
| # 3. Get counts | ||
| actual_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=1) | ||
| pred_ratings_hist = tf.reduce_sum(self.conf_mtx, axis=0) | ||
|
|
||
| # 4. Get the outer product | ||
| out_prod = pred_ratings_hist[..., None] * \ | ||
| actual_ratings_hist[None, ...] | ||
|
|
||
| # 5. Normalize the confusion matrix and outer product | ||
| conf_mtx = self.conf_mtx / tf.reduce_sum(self.conf_mtx) | ||
| out_prod = out_prod / tf.reduce_sum(out_prod) | ||
|
|
||
| conf_mtx = tf.cast(conf_mtx, dtype=tf.float32) | ||
| out_prod = tf.cast(out_prod, dtype=tf.float32) | ||
|
|
||
| # 6. Calculate Kappa score | ||
| numerator = tf.reduce_sum(conf_mtx * weight_mtx) | ||
| denominator = tf.reduce_sum(out_prod * weight_mtx) | ||
| kp = 1 - (numerator / denominator) | ||
| return kp | ||
|
|
||
| def get_config(self): | ||
| """Returns the serializable config of the metric.""" | ||
|
|
||
| config = { | ||
| "num_classes": self.num_classes, | ||
| "weightage": self.weightage, | ||
| } | ||
| base_config = super(CohenKappa, self).get_config() | ||
| return dict(list(base_config.items()) + list(config.items())) | ||
|
|
||
| def reset_states(self): | ||
| """Resets all of the metric state variables.""" | ||
|
|
||
| for v in self.variables: | ||
| K.set_value( | ||
| v, np.zeros((self.num_classes, self.num_classes), np.int32)) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| # Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """Tests for Cohen's Kappa Metric.""" | ||
|
|
||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow_addons.metrics import CohenKappa | ||
| from tensorflow_addons.utils import test_utils | ||
|
|
||
|
|
||
| @test_utils.run_all_in_graph_and_eager_modes | ||
| class CohenKappaTest(tf.test.TestCase): | ||
| def test_config(self): | ||
| kp_obj = CohenKappa(name='cohen_kappa', num_classes=5) | ||
| self.assertEqual(kp_obj.name, 'cohen_kappa') | ||
| self.assertEqual(kp_obj.dtype, tf.float32) | ||
| self.assertEqual(kp_obj.num_classes, 5) | ||
|
|
||
| # Check save and restore config | ||
| kb_obj2 = CohenKappa.from_config(kp_obj.get_config()) | ||
| self.assertEqual(kb_obj2.name, 'cohen_kappa') | ||
| self.assertEqual(kb_obj2.dtype, tf.float32) | ||
| self.assertEqual(kp_obj.num_classes, 5) | ||
|
|
||
| def initialize_vars(self): | ||
| kp_obj1 = CohenKappa(num_classes=5) | ||
| kp_obj2 = CohenKappa(num_classes=5, weightage='linear') | ||
| kp_obj3 = CohenKappa(num_classes=5, weightage='quadratic') | ||
|
|
||
| self.evaluate(tf.compat.v1.variables_initializer(kp_obj1.variables)) | ||
| self.evaluate(tf.compat.v1.variables_initializer(kp_obj2.variables)) | ||
| self.evaluate(tf.compat.v1.variables_initializer(kp_obj3.variables)) | ||
| return kp_obj1, kp_obj2, kp_obj3 | ||
|
|
||
| def update_obj_states(self, obj1, obj2, obj3, actuals, preds, weights): | ||
| update_op1 = obj1.update_state(actuals, preds, sample_weight=weights) | ||
| update_op2 = obj2.update_state(actuals, preds, sample_weight=weights) | ||
| update_op3 = obj3.update_state(actuals, preds, sample_weight=weights) | ||
|
|
||
| self.evaluate(update_op1) | ||
| self.evaluate(update_op2) | ||
| self.evaluate(update_op3) | ||
|
|
||
| def check_results(self, objs, values): | ||
| obj1, obj2, obj3 = objs | ||
| val1, val2, val3 = values | ||
|
|
||
| self.assertAllClose(val1, self.evaluate(obj1.result()), atol=1e-5) | ||
| self.assertAllClose(val2, self.evaluate(obj2.result()), atol=1e-5) | ||
| self.assertAllClose(val3, self.evaluate(obj3.result()), atol=1e-5) | ||
|
|
||
| def test_kappa_random_score(self): | ||
| actuals = [4, 4, 3, 4, 2, 4, 1, 1] | ||
| preds = [4, 4, 3, 4, 4, 2, 1, 1] | ||
| actuals = tf.constant(actuals, dtype=tf.int32) | ||
| preds = tf.constant(preds, dtype=tf.int32) | ||
|
|
||
| # Initialize | ||
| kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars() | ||
|
|
||
| # Update | ||
| self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds, None) | ||
|
|
||
| # Check results | ||
| self.check_results([kp_obj1, kp_obj2, kp_obj3], | ||
| [0.61904761, 0.62790697, 0.68932038]) | ||
|
|
||
| def test_kappa_perfect_score(self): | ||
| actuals = [4, 4, 3, 3, 2, 2, 1, 1] | ||
| preds = [4, 4, 3, 3, 2, 2, 1, 1] | ||
| actuals = tf.constant(actuals, dtype=tf.int32) | ||
| preds = tf.constant(preds, dtype=tf.int32) | ||
|
|
||
| # Initialize | ||
| kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars() | ||
|
|
||
| # Update | ||
| self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds, None) | ||
|
|
||
| # Check results | ||
| self.check_results([kp_obj1, kp_obj2, kp_obj3], [1.0, 1.0, 1.0]) | ||
|
|
||
| def test_kappa_worse_than_random(self): | ||
| actuals = [4, 4, 3, 3, 2, 2, 1, 1] | ||
| preds = [1, 2, 4, 1, 3, 3, 4, 4] | ||
| actuals = tf.constant(actuals, dtype=tf.int32) | ||
| preds = tf.constant(preds, dtype=tf.int32) | ||
|
|
||
| # Initialize | ||
| kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars() | ||
|
|
||
| # Update | ||
| self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds, None) | ||
|
|
||
| # check results | ||
| self.check_results([kp_obj1, kp_obj2, kp_obj3], | ||
| [-0.3333333, -0.52380952, -0.72727272]) | ||
|
|
||
| def test_kappa_with_sample_weights(self): | ||
| actuals = [4, 4, 3, 3, 2, 2, 1, 1] | ||
| preds = [1, 2, 4, 1, 3, 3, 4, 4] | ||
| weights = [1, 1, 2, 5, 10, 2, 3, 3] | ||
| actuals = tf.constant(actuals, dtype=tf.int32) | ||
| preds = tf.constant(preds, dtype=tf.int32) | ||
| weights = tf.constant(weights, dtype=tf.int32) | ||
|
|
||
| # Initialize | ||
| kp_obj1, kp_obj2, kp_obj3 = self.initialize_vars() | ||
|
|
||
| # Update | ||
| self.update_obj_states(kp_obj1, kp_obj2, kp_obj3, actuals, preds, | ||
| weights) | ||
|
|
||
| # check results | ||
| self.check_results([kp_obj1, kp_obj2, kp_obj3], | ||
| [-0.25473321, -0.38992332, -0.60695344]) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| tf.test.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.