diff --git a/tensorflow_addons/metrics/BUILD b/tensorflow_addons/metrics/BUILD
index 541ec88bb9..f176a95f67 100644
--- a/tensorflow_addons/metrics/BUILD
+++ b/tensorflow_addons/metrics/BUILD
@@ -7,6 +7,7 @@ py_library(
     srcs = [
         "__init__.py",
         "cohens_kappa.py",
+        "f1_scores.py",
        "r_square.py",
     ],
     srcs_version = "PY2AND3",
@@ -40,3 +41,16 @@ py_test(
         ":metrics",
     ],
 )
+
+py_test(
+    name = "f1_test",
+    size = "small",
+    srcs = [
+        "f1_test.py",
+    ],
+    main = "f1_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":metrics",
+    ],
+)
diff --git a/tensorflow_addons/metrics/README.md b/tensorflow_addons/metrics/README.md
index 4558983dee..0147d0d0ee 100644
--- a/tensorflow_addons/metrics/README.md
+++ b/tensorflow_addons/metrics/README.md
@@ -4,12 +4,14 @@
 | Submodule | Maintainers | Contact Info |
 |:---------- |:------------- |:--------------|
 | cohens_kappa| Aakash Nain | aakashnain@outlook.com|
+| f1_scores| Saishruthi Swaminathan | saishruthi.tn@gmail.com|
 | r_square| Saishruthi Swaminathan| saishruthi.tn@gmail.com|
 
 ## Contents
 | Submodule | Metric | Reference |
 |:----------------------- |:-------------------|:---------------|
 | cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)|
+| f1_scores| F1 micro, macro and weighted| [F1 Score](https://en.wikipedia.org/wiki/F1_score)|
 | r_square| RSquare|[R-Sqaure](https://en.wikipedia.org/wiki/Coefficient_of_determination)|
diff --git a/tensorflow_addons/metrics/__init__.py b/tensorflow_addons/metrics/__init__.py
old mode 100644
new mode 100755
index 6bb4197fee..985341d9c5
--- a/tensorflow_addons/metrics/__init__.py
+++ b/tensorflow_addons/metrics/__init__.py
@@ -19,4 +19,5 @@
 from __future__ import print_function
 
 from tensorflow_addons.metrics.cohens_kappa import CohenKappa
+from tensorflow_addons.metrics.f1_scores import F1Score
 from tensorflow_addons.metrics.r_square import RSquare
diff --git a/tensorflow_addons/metrics/f1_scores.py b/tensorflow_addons/metrics/f1_scores.py
new file mode 100755
index 0000000000..220583d254
--- /dev/null
+++ b/tensorflow_addons/metrics/f1_scores.py
@@ -0,0 +1,208 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements F1 scores."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.keras.metrics import Metric
+import numpy as np
+
+
+class F1Score(Metric):
+    """Calculates F1 micro, macro or weighted score based on the user's choice.
+
+    F1 score is the harmonic mean of precision and
+    recall. Output range is [0, 1]. This works for both
+    multi-class and multi-label classification.
+
+    Args:
+        num_classes : Number of unique classes in the dataset.
+        average : Type of averaging to be performed on data.
+                  Acceptable values are None, micro, macro and
+                  weighted.
+                  Default value is None.
+
+    Returns:
+        F1 score: float.
+
+    Raises:
+        ValueError: If `average` has a value other than
+            [None, micro, macro, weighted].
+
+    `average` parameter behavior:
+
+    1. If `None` is specified as an input, scores for each
+       class are returned.
+
+    2. If `micro` is specified, metrics like true positives,
+       false positives and false negatives are computed
+       globally.
+
+    3. If `macro` is specified, metrics like true positives,
+       false positives and false negatives are computed for
+       each class and their unweighted mean is returned.
+       Imbalance in the dataset is not taken into account
+       when calculating the score.
+
+    4. If `weighted` is specified, metrics are computed for
+       each class and their mean weighted by the number of
+       true instances in each class is returned, taking data
+       imbalance into account.
+
+    Usage:
+    ```python
+    actuals = tf.constant([[1, 1, 0], [1, 0, 0]],
+                          dtype=tf.int32)
+    preds = tf.constant([[1, 0, 0], [1, 0, 1]],
+                        dtype=tf.int32)
+    output = tfa.metrics.F1Score(num_classes=3,
+                                 average='micro')
+    output.update_state(actuals, preds)
+    print('F1 Micro score is: ',
+          output.result().numpy())  # 0.6666667
+    ```
+    """
+
+    def __init__(self,
+                 num_classes,
+                 average=None,
+                 name='f1_score',
+                 dtype=tf.float32):
+        super(F1Score, self).__init__(name=name, dtype=dtype)
+        self.num_classes = num_classes
+        if average not in (None, 'micro', 'macro', 'weighted'):
+            raise ValueError("Unknown average type. Acceptable values "
+                             "are: [None, micro, macro, weighted]")
+        else:
+            self.average = average
+        if self.average == 'micro':
+            self.axis = None
+        else:
+            self.axis = 0
+        if self.average == 'micro':
+            self.true_positives = self.add_weight(
+                'true_positives',
+                shape=[],
+                initializer='zeros',
+                dtype=tf.float32)
+            self.false_positives = self.add_weight(
+                'false_positives',
+                shape=[],
+                initializer='zeros',
+                dtype=tf.float32)
+            self.false_negatives = self.add_weight(
+                'false_negatives',
+                shape=[],
+                initializer='zeros',
+                dtype=tf.float32)
+        else:
+            self.true_positives = self.add_weight(
+                'true_positives',
+                shape=[self.num_classes],
+                initializer='zeros',
+                dtype=tf.float32)
+            self.false_positives = self.add_weight(
+                'false_positives',
+                shape=[self.num_classes],
+                initializer='zeros',
+                dtype=tf.float32)
+            self.false_negatives = self.add_weight(
+                'false_negatives',
+                shape=[self.num_classes],
+                initializer='zeros',
+                dtype=tf.float32)
+            self.weights_intermediate = self.add_weight(
+                'weights',
+                shape=[self.num_classes],
+                initializer='zeros',
+                dtype=tf.float32)
+
+    def update_state(self, y_true, y_pred):
+        y_true = tf.cast(y_true, tf.int32)
+        y_pred = tf.cast(y_pred, tf.int32)
+
+        # true positive
+        self.true_positives.assign_add(
+            tf.cast(
+                tf.math.count_nonzero(y_pred * y_true, axis=self.axis),
+                tf.float32))
+        # false positive
+        self.false_positives.assign_add(
+            tf.cast(
+                tf.math.count_nonzero(y_pred * (y_true - 1), axis=self.axis),
+                tf.float32))
+        # false negative
+        self.false_negatives.assign_add(
+            tf.cast(
+                tf.math.count_nonzero((y_pred - 1) * y_true, axis=self.axis),
+                tf.float32))
+        if self.average == 'weighted':
+            # variable to hold intermediate weights (per-class support)
+            self.weights_intermediate.assign_add(
+                tf.cast(tf.reduce_sum(y_true, axis=self.axis), tf.float32))
+
+    def result(self):
+        p_sum = tf.cast(self.true_positives + self.false_positives,
+                        tf.float32)
+        # calculate precision
+        precision = tf.math.divide_no_nan(self.true_positives, p_sum)
+
+        r_sum = tf.cast(self.true_positives + self.false_negatives,
+                        tf.float32)
+        # calculate recall
+        recall = tf.math.divide_no_nan(self.true_positives, r_sum)
+
+        mul_value = 2 * precision * recall
+        add_value = precision + recall
+        f1_int = tf.math.divide_no_nan(mul_value, add_value)
+        # f1 score
+        if self.average is not None:
+            f1_score = tf.reduce_mean(f1_int)
+        else:
+            f1_score = f1_int
+        # condition for weighted f1 score
+        if self.average == 'weighted':
+            f1_int_weights = tf.math.divide_no_nan(
+                self.weights_intermediate,
+                tf.reduce_sum(self.weights_intermediate))
+            # weighted f1 score calculation
+            f1_score = tf.reduce_sum(f1_int * f1_int_weights)
+
+        return f1_score
+
+    def get_config(self):
+        """Returns the serializable config of the metric."""
+
+        config = {
+            "num_classes": self.num_classes,
+            "average": self.average,
+        }
+        base_config = super(F1Score, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def reset_states(self):
+        # reset state of the variables to zero
+        if self.average == 'micro':
+            self.true_positives.assign(0)
+            self.false_positives.assign(0)
+            self.false_negatives.assign(0)
+        else:
+            self.true_positives.assign(
+                np.zeros(self.num_classes, dtype=np.float32))
+            self.false_positives.assign(
+                np.zeros(self.num_classes, dtype=np.float32))
+            self.false_negatives.assign(
+                np.zeros(self.num_classes, dtype=np.float32))
+            self.weights_intermediate.assign(
+                np.zeros(self.num_classes, dtype=np.float32))
diff --git a/tensorflow_addons/metrics/f1_test.py b/tensorflow_addons/metrics/f1_test.py
new file mode 100755
index 0000000000..323b8432f9
--- /dev/null
+++ b/tensorflow_addons/metrics/f1_test.py
@@ -0,0 +1,127 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests F1 metrics."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.metrics import F1Score
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class F1ScoreTest(tf.test.TestCase):
+    def test_config(self):
+        f1_obj = F1Score(name='f1_score', num_classes=3)
+        self.assertEqual(f1_obj.name, 'f1_score')
+        self.assertEqual(f1_obj.dtype, tf.float32)
+        self.assertEqual(f1_obj.num_classes, 3)
+        # Check save and restore config
+        f1_obj2 = F1Score.from_config(f1_obj.get_config())
+        self.assertEqual(f1_obj2.name, 'f1_score')
+        self.assertEqual(f1_obj2.dtype, tf.float32)
+        self.assertEqual(f1_obj2.num_classes, 3)
+
+    def initialize_vars(self):
+        f1_micro = F1Score(num_classes=3, average='micro')
+        f1_macro = F1Score(num_classes=3, average='macro')
+        f1_weighted = F1Score(num_classes=3, average='weighted')
+
+        self.evaluate(tf.compat.v1.variables_initializer(f1_micro.variables))
+        self.evaluate(tf.compat.v1.variables_initializer(f1_macro.variables))
+        self.evaluate(
+            tf.compat.v1.variables_initializer(f1_weighted.variables))
+        return f1_micro, f1_macro, f1_weighted
+
+    def initialize_vars_none(self):
+        f1_none = F1Score(num_classes=3, average=None)
+
+        self.evaluate(tf.compat.v1.variables_initializer(f1_none.variables))
+        return f1_none
+
+    def update_obj_states(self, f1_micro, f1_macro, f1_weighted, actuals,
+                          preds):
+        update_micro = f1_micro.update_state(actuals, preds)
+        update_macro = f1_macro.update_state(actuals, preds)
+        update_weighted = f1_weighted.update_state(actuals, preds)
+        self.evaluate(update_micro)
+        self.evaluate(update_macro)
+        self.evaluate(update_weighted)
+
+    def update_obj_states_none(self, f1_none, actuals, preds):
+        update_none = f1_none.update_state(actuals, preds)
+        self.evaluate(update_none)
+
+    def check_results(self, obj, value):
+        self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5)
+
+    def test_f1_perfect_score(self):
+        actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
+                              dtype=tf.int32)
+        preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32)
+        # Initialize
+        f1_micro, f1_macro, f1_weighted = self.initialize_vars()
+        # Update
+        self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals, preds)
+        # Check results
+        self.check_results(f1_micro, 1.0)
+        self.check_results(f1_macro, 1.0)
+        self.check_results(f1_weighted, 1.0)
+
+    def test_f1_worst_score(self):
+        actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
+                              dtype=tf.int32)
+        preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32)
+        # Initialize
+        f1_micro, f1_macro, f1_weighted = self.initialize_vars()
+        # Update
+        self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals, preds)
+        # Check results
+        self.check_results(f1_micro, 0.0)
+        self.check_results(f1_macro, 0.0)
+        self.check_results(f1_weighted, 0.0)
+
+    def test_f1_random_score(self):
+        actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
+                              dtype=tf.int32)
+        preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32)
+        # Initialize
+        f1_micro, f1_macro, f1_weighted = self.initialize_vars()
+        # Update
+        self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals, preds)
+        # Check results
+        self.check_results(f1_micro, 0.6666666)
+        self.check_results(f1_macro, 0.6555555)
+        self.check_results(f1_weighted, 0.6777777)
+
+    def test_f1_none_score(self):
+        actuals = tf.constant(
+            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
+            dtype=tf.int32)
+        preds = tf.constant(
+            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
+            dtype=tf.int32)
+        # Initialize
+        f1_none = self.initialize_vars_none()
+        # Update
+        self.update_obj_states_none(f1_none, actuals, preds)
+        # Check results
+        self.check_results(f1_none, [0.8, 0.6666667, 1.])
+
+
+if __name__ == '__main__':
+    tf.test.main()
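As a sanity check on the averaging logic in `result()` and on the expected constants in `test_f1_random_score`, the following NumPy sketch (not part of the diff) recomputes micro, macro and weighted F1 from the same per-class TP/FP/FN definitions used in `update_state()`. The label arrays are copied from the test; everything else is illustrative.

```python
import numpy as np

# Labels from test_f1_random_score: rows are samples, columns are classes.
actuals = np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]])
preds = np.array([[0, 0, 1], [1, 1, 0], [1, 1, 1]])

# Per-class counts, mirroring the count_nonzero expressions in update_state().
tp = np.count_nonzero(preds * actuals, axis=0).astype(float)        # [2. 1. 1.]
fp = np.count_nonzero(preds * (actuals - 1), axis=0).astype(float)  # [0. 1. 1.]
fn = np.count_nonzero((preds - 1) * actuals, axis=0).astype(float)  # [1. 1. 0.]

# Per-class precision, recall and F1. No denominator is zero in this example,
# so plain division stands in for the metric's divide_no_nan.
precision = tp / (tp + fp)                                    # [1.   0.5  0.5 ]
recall = tp / (tp + fn)                                       # [0.67 0.5  1.  ]
f1_per_class = 2 * precision * recall / (precision + recall)  # [0.8  0.5  0.67]

# average='macro': unweighted mean of per-class scores.
macro = f1_per_class.mean()                                   # ~0.6555555

# average='weighted': mean weighted by per-class support (column sums of actuals).
support = actuals.sum(axis=0)                                 # [3 2 1]
weighted = np.sum(f1_per_class * support / support.sum())     # ~0.6777777

# average='micro': pool TP/FP/FN over all classes before computing F1.
p_micro = tp.sum() / (tp.sum() + fp.sum())
r_micro = tp.sum() / (tp.sum() + fn.sum())
micro = 2 * p_micro * r_micro / (p_micro + r_micro)           # ~0.6666666

print(micro, macro, weighted)
```

These match the constants asserted in `test_f1_random_score` to the test's `atol=1e-5`.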
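The same expectations can be cross-checked against scikit-learn, whose `f1_score` accepts multilabel indicator arrays and the same `average` modes. This is only an offline sanity check; scikit-learn is assumed to be available and is not a dependency introduced or required by this change.

```python
import numpy as np
from sklearn.metrics import f1_score  # assumed installed; not part of this PR

actuals = np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]])
preds = np.array([[0, 0, 1], [1, 1, 0], [1, 1, 1]])

print(f1_score(actuals, preds, average='micro'))     # ~0.6666666
print(f1_score(actuals, preds, average='macro'))     # ~0.6555555
print(f1_score(actuals, preds, average='weighted'))  # ~0.6777777
print(f1_score(actuals, preds, average=None))        # per-class, ~[0.8 0.5 0.67]
```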