From 479451ef02adb828f2303566ad0373eb3f5bbc1c Mon Sep 17 00:00:00 2001 From: saishruthi Date: Wed, 12 Jun 2019 11:17:18 -0700 Subject: [PATCH 1/9] f1-macro --- tensorflow_addons/metrics/f1_macro.py | 63 +++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tensorflow_addons/metrics/f1_macro.py diff --git a/tensorflow_addons/metrics/f1_macro.py b/tensorflow_addons/metrics/f1_macro.py new file mode 100644 index 0000000000..d74ec30608 --- /dev/null +++ b/tensorflow_addons/metrics/f1_macro.py @@ -0,0 +1,63 @@ +import tensorflow as tf +from tensorflow.keras.metrics import Metric +import numpy as np + + +class F1_score(Metric): + """ + Computes F1 macro score + """ + def __init__(self, num_classes, name='f1-score'): + super(F1_score, self).__init__(name=name) + self.num_classes = num_classes + self.true_positives_col = self.add_weight('TP-class', + shape=[self.num_classes], + initializer='zeros', + dtype=tf.float32) + self.false_positives_col = self.add_weight('FP-class', + shape=[self.num_classes], + initializer='zeros', + dtype=tf.float32) + self.false_negatives_col = self.add_weight('FN-class', + shape=[self.num_classes], + initializer='zeros', + dtype=tf.float32) + + def update_state(self, y_true, y_pred): + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + # true positive across column + self.true_positives_col.assign_add(tf.cast(tf.math.count_nonzero( + y_pred * y_true, axis=0), tf.float32)) + # false positive across column + self.false_positives_col.assign_add(tf.cast(tf.math.count_nonzero( + y_pred * (y_true - 1), axis=0), tf.float32)) + # false negative across column + self.false_negatives_col.assign_add(tf.cast( + tf.math.count_nonzero((y_pred - 1) * y_true, axis=0), tf.float32)) + + def result(self): + p_sum = tf.cast(self.true_positives_col + self.false_positives_col, + tf.float32) + precision_macro = tf.cast(tf.compat.v1.div_no_nan( + self.true_positives_col, p_sum), tf.float32) + + r_sum = tf.cast(self.true_positives_col + self.false_negatives_col, + tf.float32) + recall_macro = tf.cast(tf.compat.v1.div_no_nan( + self.true_positives_col, r_sum), tf.float32) + + mul_value = 2 * precision_macro * recall_macro + add_value = precision_macro + recall_macro + f1_macro = tf.cast(tf.compat.v1.div_no_nan(mul_value, add_value), + tf.float32) + + f1_macro = tf.reduce_mean(f1_macro) + + return f1_macro + + def reset_states(self): + self.true_positives_col.assign(np.zeros(self.num_classes), np.float32) + self.false_positives_col.assign(np.zeros(self.num_classes), np.float32) + self.false_negatives_col.assign(np.zeros(self.num_classes), np.float32) From 9851f724b5e5d919ca81cf6e6e0201d3fb57b04a Mon Sep 17 00:00:00 2001 From: saishruthi Date: Thu, 13 Jun 2019 15:14:19 -0700 Subject: [PATCH 2/9] Updating complete F1 score --- tensorflow_addons/metrics/f1_macro.py | 63 ---------- tensorflow_addons/metrics/f1_scores.py | 166 +++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 63 deletions(-) delete mode 100644 tensorflow_addons/metrics/f1_macro.py create mode 100644 tensorflow_addons/metrics/f1_scores.py diff --git a/tensorflow_addons/metrics/f1_macro.py b/tensorflow_addons/metrics/f1_macro.py deleted file mode 100644 index d74ec30608..0000000000 --- a/tensorflow_addons/metrics/f1_macro.py +++ /dev/null @@ -1,63 +0,0 @@ -import tensorflow as tf -from tensorflow.keras.metrics import Metric -import numpy as np - - -class F1_score(Metric): - """ - Computes F1 macro score - """ - def __init__(self, num_classes, name='f1-score'): - 
super(F1_score, self).__init__(name=name) - self.num_classes = num_classes - self.true_positives_col = self.add_weight('TP-class', - shape=[self.num_classes], - initializer='zeros', - dtype=tf.float32) - self.false_positives_col = self.add_weight('FP-class', - shape=[self.num_classes], - initializer='zeros', - dtype=tf.float32) - self.false_negatives_col = self.add_weight('FN-class', - shape=[self.num_classes], - initializer='zeros', - dtype=tf.float32) - - def update_state(self, y_true, y_pred): - y_true = tf.cast(y_true, tf.int32) - y_pred = tf.cast(y_pred, tf.int32) - - # true positive across column - self.true_positives_col.assign_add(tf.cast(tf.math.count_nonzero( - y_pred * y_true, axis=0), tf.float32)) - # false positive across column - self.false_positives_col.assign_add(tf.cast(tf.math.count_nonzero( - y_pred * (y_true - 1), axis=0), tf.float32)) - # false negative across column - self.false_negatives_col.assign_add(tf.cast( - tf.math.count_nonzero((y_pred - 1) * y_true, axis=0), tf.float32)) - - def result(self): - p_sum = tf.cast(self.true_positives_col + self.false_positives_col, - tf.float32) - precision_macro = tf.cast(tf.compat.v1.div_no_nan( - self.true_positives_col, p_sum), tf.float32) - - r_sum = tf.cast(self.true_positives_col + self.false_negatives_col, - tf.float32) - recall_macro = tf.cast(tf.compat.v1.div_no_nan( - self.true_positives_col, r_sum), tf.float32) - - mul_value = 2 * precision_macro * recall_macro - add_value = precision_macro + recall_macro - f1_macro = tf.cast(tf.compat.v1.div_no_nan(mul_value, add_value), - tf.float32) - - f1_macro = tf.reduce_mean(f1_macro) - - return f1_macro - - def reset_states(self): - self.true_positives_col.assign(np.zeros(self.num_classes), np.float32) - self.false_positives_col.assign(np.zeros(self.num_classes), np.float32) - self.false_negatives_col.assign(np.zeros(self.num_classes), np.float32) diff --git a/tensorflow_addons/metrics/f1_scores.py b/tensorflow_addons/metrics/f1_scores.py new file mode 100644 index 0000000000..9a16ce23f0 --- /dev/null +++ b/tensorflow_addons/metrics/f1_scores.py @@ -0,0 +1,166 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Implements F1 scores.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.keras.metrics import Metric +import numpy as np + + +class F1_micro(Metric): + """ + Calculates F1 micro score + """ + def __init__(self, name='f1-score'): + super(F1_micro, self).__init__(name=name) + self.true_positives = self.add_weight('TP-class', shape=[], + initializer='zeros', + dtype=tf.float32) + self.false_positives = self.add_weight('FP-class', shape=[], + initializer='zeros', + dtype=tf.float32) + self.false_negatives = self.add_weight('FN-class', shape=[], + initializer='zeros', + dtype=tf.float32) + + def update_state(self, y_true, y_pred): + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + # true positive across column + self.true_positives.assign_add(tf.cast(tf.math.count_nonzero( + y_pred * y_true, axis=None), tf.float32)) + # false positive across column + self.false_positives.assign_add(tf.cast(tf.math.count_nonzero( + y_pred * (y_true - 1), axis=None), tf.float32)) + # false negative across column + self.false_negatives.assign_add(tf.cast( + tf.math.count_nonzero((y_pred - 1) * y_true, axis=None), + tf.float32)) + + def result(self): + p_sum = tf.cast(self.true_positives + self.false_positives, + tf.float32) + # precision calculation + precision_micro = tf.cast(tf.math.divide_no_nan( + self.true_positives, p_sum), tf.float32) + + r_sum = tf.cast(self.true_positives + self.false_negatives, + tf.float32) + # recall calculation + recall_micro = tf.cast(tf.math.divide_no_nan( + self.true_positives, r_sum), tf.float32) + + mul_value = 2 * precision_micro * recall_micro + add_value = precision_micro + recall_micro + f1_micro = tf.cast(tf.math.divide_no_nan(mul_value, add_value), + tf.float32) + # f1 score calculation + f1_micro = tf.reduce_mean(f1_micro) + + return f1_micro + + def reset_states(self): + # reset state of the variables to zero + self.true_positives.assign(0) + self.false_positives.assign(0) + self.false_negatives.assign(0) + + +class F1_macro_and_weighted(Metric): + """ + Calculates F1 macro or weighted based on the user's choice + """ + def __init__(self, num_classes, average, + name='f1-macro-and-weighted-score'): + super(F1_macro_and_weighted, self).__init__(name=name) + self.num_classes = num_classes + self.average = average + self.true_positives_col = self.add_weight('TP-class', + shape=[self.num_classes], + initializer='zeros', + dtype=tf.float32) + self.false_positives_col = self.add_weight('FP-class', + shape=[self.num_classes], + initializer='zeros', + dtype=tf.float32) + self.false_negatives_col = self.add_weight('FN-class', + shape=[self.num_classes], + initializer='zeros', + dtype=tf.float32) + self.weights_intermediate = self.add_weight('weights-int-f1', + shape=[self.num_classes], + initializer='zeros', + dtype=tf.float32) + + def update_state(self, y_true, y_pred): + y_true = tf.cast(y_true, tf.int32) + y_pred = tf.cast(y_pred, tf.int32) + + # true positive across column + self.true_positives_col.assign_add(tf.cast( + tf.math.count_nonzero(y_pred * y_true, axis=0), tf.float32)) + # false positive across column + self.false_positives_col.assign_add( + tf.cast(tf.math.count_nonzero(y_pred * (y_true - 1), axis=0), + tf.float32)) + # false negative across column + self.false_negatives_col.assign_add(tf.cast(tf.math.count_nonzero( + (y_pred - 1) * y_true, 
axis=0), tf.float32)) + # variable to hold intermediate weights + self.weights_intermediate.assign_add(tf.cast( + tf.reduce_sum(y_true, axis=0), tf.float32)) + + def result(self): + p_sum = tf.cast(self.true_positives_col + self.false_positives_col, + tf.float32) + # calculate precision + precision_macro = tf.cast(tf.math.divide_no_nan( + self.true_positives_col, p_sum), tf.float32) + + r_sum = tf.cast(self.true_positives_col + self.false_negatives_col, + tf.float32) + # calculate recall + recall_macro = tf.cast(tf.math.divide_no_nan( + self.true_positives_col, r_sum), tf.float32) + + mul_value = 2 * precision_macro * recall_macro + add_value = precision_macro + recall_macro + f1_macro_int = tf.cast(tf.math.divide_no_nan(mul_value, add_value), + tf.float32) + # f1 macro score + f1_score = tf.reduce_mean(f1_macro_int) + # condition for weighted f1 score + if self.average == 'weighted': + f1_int_weights = tf.cast(tf.math.divide_no_nan( + self.weights_intermediate, tf.reduce_sum( + self.weights_intermediate)), + tf.float32) + # weighted f1 score calculation + f1_score = tf.reduce_sum(f1_macro_int * f1_int_weights) + + return f1_score + + def reset_states(self): + # reset state of the variables to zero + self.true_positives_col.assign(np.zeros(self.num_classes), np.float32) + self.false_positives_col.assign(np.zeros(self.num_classes), np.float32) + self.false_negatives_col.assign(np.zeros(self.num_classes), np.float32) + self.weights_intermediate.assign(np.zeros(self.num_classes), + np.float32) From 361df849f94b27090dfaf7c4e8a86a7332de1b5a Mon Sep 17 00:00:00 2001 From: saishruthi Date: Fri, 14 Jun 2019 12:16:23 -0700 Subject: [PATCH 3/9] Updates to f1 metric --- tensorflow_addons/metrics/BUILD | 38 +++++++ tensorflow_addons/metrics/README.md | 2 + tensorflow_addons/metrics/__init__.py | 5 +- .../metrics/f1_macro_and_weighted_test.py | 98 +++++++++++++++++++ tensorflow_addons/metrics/f1_micro_test.py | 86 ++++++++++++++++ tensorflow_addons/metrics/f1_scores.py | 69 +++++++------ 6 files changed, 268 insertions(+), 30 deletions(-) create mode 100644 tensorflow_addons/metrics/f1_macro_and_weighted_test.py create mode 100644 tensorflow_addons/metrics/f1_micro_test.py diff --git a/tensorflow_addons/metrics/BUILD b/tensorflow_addons/metrics/BUILD index b1208e465e..6c94584b40 100644 --- a/tensorflow_addons/metrics/BUILD +++ b/tensorflow_addons/metrics/BUILD @@ -26,3 +26,41 @@ py_test( ":metrics", ], ) + +py_library( + name = "metrics", + srcs = [ + "__init__.py", + "f1_scores.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow_addons/utils", + ], +) + +py_test( + name = "f1_micro_test", + size = "small", + srcs = [ + "f1_micro_test.py", + ], + main = "f1_micro_test.py", + srcs_version = "PY2AND3", + deps = [ + ":metrics", + ], +) + +py_test( + name = "f1_macro_and_weighted_test", + size = "small", + srcs = [ + "f1_macro_and_weighted_test.py", + ], + main = "f1_macro_and_weighted_test.py", + srcs_version = "PY2AND3", + deps = [ + ":metrics", + ], +) diff --git a/tensorflow_addons/metrics/README.md b/tensorflow_addons/metrics/README.md index 4a978fbe08..0e663bf810 100644 --- a/tensorflow_addons/metrics/README.md +++ b/tensorflow_addons/metrics/README.md @@ -4,11 +4,13 @@ | Submodule | Maintainers | Contact Info | |:---------- |:------------- |:--------------| | cohens_kappa| Aakash Nain | aakashnain@outlook.com| +| f1_scores| Saishruthi Swaminathan | saishruthi.tn@gmail.com| ## Contents | Submodule | Metric | Reference | |:----------------------- |:-------------------|:---------------| | 
cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)| +| f1_scores| F1 micro, macro and weighted| [F1 Score](https://en.wikipedia.org/wiki/F1_score)| ## Contribution Guidelines diff --git a/tensorflow_addons/metrics/__init__.py b/tensorflow_addons/metrics/__init__.py index 4610eb870d..e0b912ce0d 100644 --- a/tensorflow_addons/metrics/__init__.py +++ b/tensorflow_addons/metrics/__init__.py @@ -18,4 +18,7 @@ from __future__ import division from __future__ import print_function -from tensorflow_addons.metrics.cohens_kappa import CohenKappa \ No newline at end of file +from tensorflow_addons.metrics.cohens_kappa import CohenKappa +from tensorflow_addons.metrics.f1_scores import F1Micro +from tensorflow_addons.metrics.f1_scores import F1MacroAndWeighted + diff --git a/tensorflow_addons/metrics/f1_macro_and_weighted_test.py b/tensorflow_addons/metrics/f1_macro_and_weighted_test.py new file mode 100644 index 0000000000..bb6d5f0ff7 --- /dev/null +++ b/tensorflow_addons/metrics/f1_macro_and_weighted_test.py @@ -0,0 +1,98 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for F1 macro and weighted metric.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_addons.metrics import F1MacroAndWeighted +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class F1MacroAndWeightedTest(tf.test.TestCase): + def test_config(self): + f1_obj = F1MacroAndWeighted(name='f1_macro_and_weighted_score', + num_classes=3) + self.assertEqual(f1_obj.name, 'f1_macro_and_weighted_score') + self.assertEqual(f1_obj.dtype, tf.float32) + self.assertEqual(f1_obj.num_classes, 3) + # Check save and restore config + f1_obj2 = F1MacroAndWeighted.from_config(f1_obj.get_config()) + self.assertEqual(f1_obj2.name, 'f1_macro_and_weighted_score') + self.assertEqual(f1_obj2.dtype, tf.float32) + self.assertEqual(f1_obj2.num_classes, 3) + + def initialize_vars(self): + f1_obj = F1MacroAndWeighted(num_classes=3, average='macro') + f1_obj1 = F1MacroAndWeighted(num_classes=3, average='weighted') + + self.evaluate(tf.compat.v1.variables_initializer(f1_obj.variables)) + self.evaluate(tf.compat.v1.variables_initializer(f1_obj1.variables)) + return f1_obj, f1_obj1 + + def update_obj_states(self, f1_obj, f1_obj1, actuals, preds): + update_op1 = f1_obj.update_state(actuals, preds) + update_op2 = f1_obj1.update_state(actuals, preds) + self.evaluate(update_op1) + self.evaluate(update_op2) + + def check_results(self, obj, value): + self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5) + + def test_f1_perfect_score(self): + actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] + preds = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] + actuals = tf.constant(actuals, dtype=tf.float32) + preds = tf.constant(preds, dtype=tf.float32) 
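+        # Editor's note (illustrative commentary, not in the original
+        # patch): with preds identical to actuals, every nonzero entry
+        # is a true positive and there are no false positives or false
+        # negatives, so per-class precision and recall are both 1 and
+        # the macro and weighted F1 scores checked below are 1.0.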
+ # Initialize + f1_obj, f1_obj1 = self.initialize_vars() + # Update + self.update_obj_states(f1_obj, f1_obj1, actuals, preds) + # Check results + self.check_results(f1_obj, 1.0) + self.check_results(f1_obj1, 1.0) + + def test_f1_worst_score(self): + actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] + preds = [[0, 0, 0], [0, 1, 0], [0, 0, 1]] + actuals = tf.constant(actuals, dtype=tf.float32) + preds = tf.constant(preds, dtype=tf.float32) + # Initialize + f1_obj, f1_obj1 = self.initialize_vars() + # Update + self.update_obj_states(f1_obj, f1_obj1, actuals, preds) + # Check results + self.check_results(f1_obj, 0.0) + self.check_results(f1_obj1, 0.0) + + def test_f1_random_score(self): + actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] + preds = [[0, 0, 1], [1, 1, 0], [1, 1, 1]] + actuals = tf.constant(actuals, dtype=tf.float32) + preds = tf.constant(preds, dtype=tf.float32) + # Initialize + f1_obj, f1_obj1 = self.initialize_vars() + # Update + self.update_obj_states(f1_obj, f1_obj1, actuals, preds) + # Check results + self.check_results(f1_obj, 0.6555555) + self.check_results(f1_obj1, 0.6777777) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_addons/metrics/f1_micro_test.py b/tensorflow_addons/metrics/f1_micro_test.py new file mode 100644 index 0000000000..587c07524c --- /dev/null +++ b/tensorflow_addons/metrics/f1_micro_test.py @@ -0,0 +1,86 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests F1 micro metric.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_addons.metrics import F1Micro +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class F1MicroTest(tf.test.TestCase): + def test_config(self): + f1_micro_obj = F1Micro(name='f1_micro_score') + self.assertEqual(f1_micro_obj.name, 'f1_micro_score') + self.assertEqual(f1_micro_obj.dtype, tf.float32) + # Check save and restore config + f1_micro_obj2 = F1Micro.from_config(f1_micro_obj.get_config()) + self.assertEqual(f1_micro_obj2.name, 'f1_micro_score') + self.assertEqual(f1_micro_obj2.dtype, tf.float32) + + def initialize_vars(self): + f1_micro_obj = F1Micro() + self.evaluate(tf.compat.v1.variables_initializer( + f1_micro_obj.variables)) + return f1_micro_obj + + def update_obj_states(self, obj, actuals, preds): + update_op = obj.update_state(actuals, preds) + self.evaluate(update_op) + + def check_results(self, obj, value): + self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5) + + def test_f1_micro_perfect_score(self): + actuals = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] + preds = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] + actuals = tf.constant(actuals, dtype=tf.float32) + preds = tf.constant(preds, dtype=tf.float32) + # Initialize + f1_micro_obj = self.initialize_vars() + # Update + self.update_obj_states(f1_micro_obj, actuals, preds) + # Check results + self.check_results(f1_micro_obj, 1.0) + + def test_f1_micro_worst_score(self): + actuals = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] + preds = [[0, 0, 0], [0, 1, 1], [0, 0, 0]] + actuals = tf.constant(actuals, dtype=tf.float32) + preds = tf.constant(preds, dtype=tf.float32) + f1_micro_obj = self.initialize_vars() + # Update + self.update_obj_states(f1_micro_obj, actuals, preds) + # Check results + self.check_results(f1_micro_obj, 0.0) + + def test_f1_micro_random_score(self): + actuals = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] + preds = [[1, 1, 0], [1, 1, 1], [0, 1, 0]] + actuals = tf.constant(actuals, dtype=tf.float32) + preds = tf.constant(preds, dtype=tf.float32) + f1_micro_obj = self.initialize_vars() + # Update + self.update_obj_states(f1_micro_obj, actuals, preds) + # Check results + self.check_results(f1_micro_obj, 0.7272727) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_addons/metrics/f1_scores.py b/tensorflow_addons/metrics/f1_scores.py index 9a16ce23f0..1ef39f8e62 100644 --- a/tensorflow_addons/metrics/f1_scores.py +++ b/tensorflow_addons/metrics/f1_scores.py @@ -23,19 +23,19 @@ import numpy as np -class F1_micro(Metric): +class F1Micro(Metric): """ Calculates F1 micro score """ - def __init__(self, name='f1-score'): - super(F1_micro, self).__init__(name=name) - self.true_positives = self.add_weight('TP-class', shape=[], + def __init__(self, name='f1_micro_score', dtype=tf.float32): + super(F1Micro, self).__init__(name=name) + self.true_positives = self.add_weight('true_positives', shape=[], initializer='zeros', dtype=tf.float32) - self.false_positives = self.add_weight('FP-class', shape=[], + self.false_positives = self.add_weight('false_positives', shape=[], initializer='zeros', dtype=tf.float32) - self.false_negatives = self.add_weight('FN-class', shape=[], + self.false_negatives = self.add_weight('false_negatives', shape=[], initializer='zeros', dtype=tf.float32) @@ -58,20 +58,18 @@ def result(self): 
p_sum = tf.cast(self.true_positives + self.false_positives, tf.float32) # precision calculation - precision_micro = tf.cast(tf.math.divide_no_nan( - self.true_positives, p_sum), tf.float32) + precision_micro = tf.math.divide_no_nan(self.true_positives, + p_sum) r_sum = tf.cast(self.true_positives + self.false_negatives, tf.float32) # recall calculation - recall_micro = tf.cast(tf.math.divide_no_nan( - self.true_positives, r_sum), tf.float32) - + recall_micro = tf.math.divide_no_nan(self.true_positives, + r_sum) + # f1 micro score calculation mul_value = 2 * precision_micro * recall_micro add_value = precision_micro + recall_micro - f1_micro = tf.cast(tf.math.divide_no_nan(mul_value, add_value), - tf.float32) - # f1 score calculation + f1_micro = tf.math.divide_no_nan(mul_value, add_value) f1_micro = tf.reduce_mean(f1_micro) return f1_micro @@ -83,28 +81,33 @@ def reset_states(self): self.false_negatives.assign(0) -class F1_macro_and_weighted(Metric): +class F1MacroAndWeighted(Metric): """ Calculates F1 macro or weighted based on the user's choice """ - def __init__(self, num_classes, average, - name='f1-macro-and-weighted-score'): - super(F1_macro_and_weighted, self).__init__(name=name) + + def __init__(self, num_classes, average=None, + name='f1_macro_and_weighted_score', dtype=tf.float32): + super(F1MacroAndWeighted, self).__init__(name=name) self.num_classes = num_classes - self.average = average - self.true_positives_col = self.add_weight('TP-class', + if average not in (None, 'macro', 'weighted'): + raise ValueError("Unknown average type. Acceptable values " + "are: [macro, weighted]") + else: + self.average = average + self.true_positives_col = self.add_weight('true_positives', shape=[self.num_classes], initializer='zeros', dtype=tf.float32) - self.false_positives_col = self.add_weight('FP-class', + self.false_positives_col = self.add_weight('false_positives', shape=[self.num_classes], initializer='zeros', dtype=tf.float32) - self.false_negatives_col = self.add_weight('FN-class', + self.false_negatives_col = self.add_weight('false_negatives', shape=[self.num_classes], initializer='zeros', dtype=tf.float32) - self.weights_intermediate = self.add_weight('weights-int-f1', + self.weights_intermediate = self.add_weight('weights', shape=[self.num_classes], initializer='zeros', dtype=tf.float32) @@ -131,19 +134,18 @@ def result(self): p_sum = tf.cast(self.true_positives_col + self.false_positives_col, tf.float32) # calculate precision - precision_macro = tf.cast(tf.math.divide_no_nan( - self.true_positives_col, p_sum), tf.float32) + precision_macro = tf.math.divide_no_nan(self.true_positives_col, + p_sum) r_sum = tf.cast(self.true_positives_col + self.false_negatives_col, tf.float32) # calculate recall - recall_macro = tf.cast(tf.math.divide_no_nan( - self.true_positives_col, r_sum), tf.float32) + recall_macro = tf.math.divide_no_nan(self.true_positives_col, + r_sum) mul_value = 2 * precision_macro * recall_macro add_value = precision_macro + recall_macro - f1_macro_int = tf.cast(tf.math.divide_no_nan(mul_value, add_value), - tf.float32) + f1_macro_int = tf.math.divide_no_nan(mul_value, add_value) # f1 macro score f1_score = tf.reduce_mean(f1_macro_int) # condition for weighted f1 score @@ -157,6 +159,15 @@ def result(self): return f1_score + def get_config(self): + """Returns the serializable config of the metric.""" + config = { + "num_classes": self.num_classes, + "average": self.average, + } + base_config = super(F1MacroAndWeighted, self).get_config() + return 
dict(list(base_config.items()) + list(config.items())) + def reset_states(self): # reset state of the variables to zero self.true_positives_col.assign(np.zeros(self.num_classes), np.float32) From a88c5981660eb912e8afd3e419370d84c442c533 Mon Sep 17 00:00:00 2001 From: saishruthi Date: Sun, 16 Jun 2019 14:03:44 -0700 Subject: [PATCH 4/9] new updates --- tensorflow_addons/metrics/BUILD | 28 +--- tensorflow_addons/metrics/__init__.py | 3 +- tensorflow_addons/metrics/f1_micro_test.py | 86 ---------- tensorflow_addons/metrics/f1_scores.py | 155 +++++++----------- ..._macro_and_weighted_test.py => f1_test.py} | 47 +++--- 5 files changed, 91 insertions(+), 228 deletions(-) delete mode 100644 tensorflow_addons/metrics/f1_micro_test.py rename tensorflow_addons/metrics/{f1_macro_and_weighted_test.py => f1_test.py} (66%) diff --git a/tensorflow_addons/metrics/BUILD b/tensorflow_addons/metrics/BUILD index 6c94584b40..58874a8bc1 100644 --- a/tensorflow_addons/metrics/BUILD +++ b/tensorflow_addons/metrics/BUILD @@ -7,6 +7,7 @@ py_library( srcs = [ "__init__.py", "cohens_kappa.py", + "f1_scores.py", ], srcs_version = "PY2AND3", deps = [ @@ -27,23 +28,11 @@ py_test( ], ) -py_library( - name = "metrics", - srcs = [ - "__init__.py", - "f1_scores.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow_addons/utils", - ], -) - py_test( name = "f1_micro_test", size = "small", srcs = [ - "f1_micro_test.py", + "f1_test.py", ], main = "f1_micro_test.py", srcs_version = "PY2AND3", @@ -51,16 +40,3 @@ py_test( ":metrics", ], ) - -py_test( - name = "f1_macro_and_weighted_test", - size = "small", - srcs = [ - "f1_macro_and_weighted_test.py", - ], - main = "f1_macro_and_weighted_test.py", - srcs_version = "PY2AND3", - deps = [ - ":metrics", - ], -) diff --git a/tensorflow_addons/metrics/__init__.py b/tensorflow_addons/metrics/__init__.py index e0b912ce0d..f6df5a27fe 100644 --- a/tensorflow_addons/metrics/__init__.py +++ b/tensorflow_addons/metrics/__init__.py @@ -19,6 +19,5 @@ from __future__ import print_function from tensorflow_addons.metrics.cohens_kappa import CohenKappa -from tensorflow_addons.metrics.f1_scores import F1Micro -from tensorflow_addons.metrics.f1_scores import F1MacroAndWeighted +from tensorflow_addons.metrics.f1_scores import F1Score diff --git a/tensorflow_addons/metrics/f1_micro_test.py b/tensorflow_addons/metrics/f1_micro_test.py deleted file mode 100644 index 587c07524c..0000000000 --- a/tensorflow_addons/metrics/f1_micro_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests F1 micro metric.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow_addons.metrics import F1Micro -from tensorflow_addons.utils import test_utils - - -@test_utils.run_all_in_graph_and_eager_modes -class F1MicroTest(tf.test.TestCase): - def test_config(self): - f1_micro_obj = F1Micro(name='f1_micro_score') - self.assertEqual(f1_micro_obj.name, 'f1_micro_score') - self.assertEqual(f1_micro_obj.dtype, tf.float32) - # Check save and restore config - f1_micro_obj2 = F1Micro.from_config(f1_micro_obj.get_config()) - self.assertEqual(f1_micro_obj2.name, 'f1_micro_score') - self.assertEqual(f1_micro_obj2.dtype, tf.float32) - - def initialize_vars(self): - f1_micro_obj = F1Micro() - self.evaluate(tf.compat.v1.variables_initializer( - f1_micro_obj.variables)) - return f1_micro_obj - - def update_obj_states(self, obj, actuals, preds): - update_op = obj.update_state(actuals, preds) - self.evaluate(update_op) - - def check_results(self, obj, value): - self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5) - - def test_f1_micro_perfect_score(self): - actuals = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] - preds = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] - actuals = tf.constant(actuals, dtype=tf.float32) - preds = tf.constant(preds, dtype=tf.float32) - # Initialize - f1_micro_obj = self.initialize_vars() - # Update - self.update_obj_states(f1_micro_obj, actuals, preds) - # Check results - self.check_results(f1_micro_obj, 1.0) - - def test_f1_micro_worst_score(self): - actuals = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] - preds = [[0, 0, 0], [0, 1, 1], [0, 0, 0]] - actuals = tf.constant(actuals, dtype=tf.float32) - preds = tf.constant(preds, dtype=tf.float32) - f1_micro_obj = self.initialize_vars() - # Update - self.update_obj_states(f1_micro_obj, actuals, preds) - # Check results - self.check_results(f1_micro_obj, 0.0) - - def test_f1_micro_random_score(self): - actuals = [[1, 1, 0], [1, 0, 0], [1, 1, 0]] - preds = [[1, 1, 0], [1, 1, 1], [0, 1, 0]] - actuals = tf.constant(actuals, dtype=tf.float32) - preds = tf.constant(preds, dtype=tf.float32) - f1_micro_obj = self.initialize_vars() - # Update - self.update_obj_states(f1_micro_obj, actuals, preds) - # Check results - self.check_results(f1_micro_obj, 0.7272727) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow_addons/metrics/f1_scores.py b/tensorflow_addons/metrics/f1_scores.py index 1ef39f8e62..59ae4cda8d 100644 --- a/tensorflow_addons/metrics/f1_scores.py +++ b/tensorflow_addons/metrics/f1_scores.py @@ -23,124 +23,85 @@ import numpy as np -class F1Micro(Metric): +class F1Score(Metric): """ - Calculates F1 micro score - """ - def __init__(self, name='f1_micro_score', dtype=tf.float32): - super(F1Micro, self).__init__(name=name) - self.true_positives = self.add_weight('true_positives', shape=[], - initializer='zeros', - dtype=tf.float32) - self.false_positives = self.add_weight('false_positives', shape=[], - initializer='zeros', - dtype=tf.float32) - self.false_negatives = self.add_weight('false_negatives', shape=[], - initializer='zeros', - dtype=tf.float32) - - def update_state(self, y_true, y_pred): - y_true = tf.cast(y_true, tf.int32) - y_pred = tf.cast(y_pred, tf.int32) - - # true positive across column - self.true_positives.assign_add(tf.cast(tf.math.count_nonzero( - y_pred * y_true, axis=None), tf.float32)) - # false positive across column 
- self.false_positives.assign_add(tf.cast(tf.math.count_nonzero( - y_pred * (y_true - 1), axis=None), tf.float32)) - # false negative across column - self.false_negatives.assign_add(tf.cast( - tf.math.count_nonzero((y_pred - 1) * y_true, axis=None), - tf.float32)) - - def result(self): - p_sum = tf.cast(self.true_positives + self.false_positives, - tf.float32) - # precision calculation - precision_micro = tf.math.divide_no_nan(self.true_positives, - p_sum) - - r_sum = tf.cast(self.true_positives + self.false_negatives, - tf.float32) - # recall calculation - recall_micro = tf.math.divide_no_nan(self.true_positives, - r_sum) - # f1 micro score calculation - mul_value = 2 * precision_micro * recall_micro - add_value = precision_micro + recall_micro - f1_micro = tf.math.divide_no_nan(mul_value, add_value) - f1_micro = tf.reduce_mean(f1_micro) - - return f1_micro - - def reset_states(self): - # reset state of the variables to zero - self.true_positives.assign(0) - self.false_positives.assign(0) - self.false_negatives.assign(0) - - -class F1MacroAndWeighted(Metric): - """ - Calculates F1 macro or weighted based on the user's choice + Calculates F1 micro, macro or weighted based on the + user's choice """ def __init__(self, num_classes, average=None, - name='f1_macro_and_weighted_score', dtype=tf.float32): - super(F1MacroAndWeighted, self).__init__(name=name) + name='f1_score', dtype=tf.float32): + super(F1Score, self).__init__(name=name) self.num_classes = num_classes - if average not in (None, 'macro', 'weighted'): + if average not in (None, 'micro', 'macro', 'weighted'): raise ValueError("Unknown average type. Acceptable values " - "are: [macro, weighted]") + "are: [micro, macro, weighted]") else: self.average = average - self.true_positives_col = self.add_weight('true_positives', + if self.average == 'micro': + self.axis = None + if self.average == 'macro' or self.average == 'weighted': + self.axis = 0 + if self.average == 'micro': + self.true_positives = self.add_weight('true_positives', + shape=[], + initializer='zeros', + dtype=tf.float32) + self.false_positives = self.add_weight('false_positives', + shape=[], + initializer='zeros', + dtype=tf.float32) + self.false_negatives = self.add_weight('false_negatives', + shape=[], + initializer='zeros', + dtype=tf.float32) + else: + self.true_positives = self.add_weight('true_positives', shape=[self.num_classes], initializer='zeros', dtype=tf.float32) - self.false_positives_col = self.add_weight('false_positives', + self.false_positives = self.add_weight('false_positives', shape=[self.num_classes], initializer='zeros', dtype=tf.float32) - self.false_negatives_col = self.add_weight('false_negatives', + self.false_negatives = self.add_weight('false_negatives', shape=[self.num_classes], initializer='zeros', dtype=tf.float32) - self.weights_intermediate = self.add_weight('weights', - shape=[self.num_classes], - initializer='zeros', - dtype=tf.float32) + self.weights_intermediate = self.add_weight('weights', shape=[ + self.num_classes], initializer='zeros', dtype=tf.float32) def update_state(self, y_true, y_pred): y_true = tf.cast(y_true, tf.int32) y_pred = tf.cast(y_pred, tf.int32) - # true positive across column - self.true_positives_col.assign_add(tf.cast( - tf.math.count_nonzero(y_pred * y_true, axis=0), tf.float32)) - # false positive across column - self.false_positives_col.assign_add( - tf.cast(tf.math.count_nonzero(y_pred * (y_true - 1), axis=0), - tf.float32)) - # false negative across column - 
self.false_negatives_col.assign_add(tf.cast(tf.math.count_nonzero( - (y_pred - 1) * y_true, axis=0), tf.float32)) - # variable to hold intermediate weights - self.weights_intermediate.assign_add(tf.cast( - tf.reduce_sum(y_true, axis=0), tf.float32)) + # true positive + self.true_positives.assign_add(tf.cast( + tf.math.count_nonzero(y_pred * y_true, axis=self.axis), + tf.float32)) + # false positive + self.false_positives.assign_add( + tf.cast(tf.math.count_nonzero(y_pred * (y_true - 1), + axis=self.axis), tf.float32)) + # false negative + self.false_negatives.assign_add(tf.cast(tf.math.count_nonzero( + (y_pred - 1) * y_true, axis=self.axis), tf.float32)) + if self.average == 'weighted': + # variable to hold intermediate weights + self.weights_intermediate.assign_add(tf.cast( + tf.reduce_sum(y_true, axis=self.axis), tf.float32)) def result(self): - p_sum = tf.cast(self.true_positives_col + self.false_positives_col, + p_sum = tf.cast(self.true_positives + self.false_positives, tf.float32) # calculate precision - precision_macro = tf.math.divide_no_nan(self.true_positives_col, + precision_macro = tf.math.divide_no_nan(self.true_positives, p_sum) - r_sum = tf.cast(self.true_positives_col + self.false_negatives_col, + r_sum = tf.cast(self.true_positives + self.false_negatives, tf.float32) # calculate recall - recall_macro = tf.math.divide_no_nan(self.true_positives_col, + recall_macro = tf.math.divide_no_nan(self.true_positives, r_sum) mul_value = 2 * precision_macro * recall_macro @@ -161,17 +122,23 @@ def result(self): def get_config(self): """Returns the serializable config of the metric.""" + config = { "num_classes": self.num_classes, "average": self.average, } - base_config = super(F1MacroAndWeighted, self).get_config() + base_config = super(F1Score, self).get_config() return dict(list(base_config.items()) + list(config.items())) def reset_states(self): # reset state of the variables to zero - self.true_positives_col.assign(np.zeros(self.num_classes), np.float32) - self.false_positives_col.assign(np.zeros(self.num_classes), np.float32) - self.false_negatives_col.assign(np.zeros(self.num_classes), np.float32) - self.weights_intermediate.assign(np.zeros(self.num_classes), - np.float32) + if self.average == 'micro': + self.true_positives.assign(0) + self.false_positives.assign(0) + self.false_negatives.assign(0) + else: + self.true_positives.assign(np.zeros(self.num_classes), np.float32) + self.false_positives.assign(np.zeros(self.num_classes), np.float32) + self.false_negatives.assign(np.zeros(self.num_classes), np.float32) + self.weights_intermediate.assign(np.zeros(self.num_classes), + np.float32) diff --git a/tensorflow_addons/metrics/f1_macro_and_weighted_test.py b/tensorflow_addons/metrics/f1_test.py similarity index 66% rename from tensorflow_addons/metrics/f1_macro_and_weighted_test.py rename to tensorflow_addons/metrics/f1_test.py index bb6d5f0ff7..b1b4472171 100644 --- a/tensorflow_addons/metrics/f1_macro_and_weighted_test.py +++ b/tensorflow_addons/metrics/f1_test.py @@ -12,44 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for F1 macro and weighted metric.""" +"""Tests F1 micro metric.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf -from tensorflow_addons.metrics import F1MacroAndWeighted +from tensorflow_addons.metrics import F1Score from tensorflow_addons.utils import test_utils @test_utils.run_all_in_graph_and_eager_modes -class F1MacroAndWeightedTest(tf.test.TestCase): +class F1ScoreTest(tf.test.TestCase): def test_config(self): - f1_obj = F1MacroAndWeighted(name='f1_macro_and_weighted_score', - num_classes=3) - self.assertEqual(f1_obj.name, 'f1_macro_and_weighted_score') + f1_obj = F1Score(name='f1_score', + num_classes=3) + self.assertEqual(f1_obj.name, 'f1_score') self.assertEqual(f1_obj.dtype, tf.float32) self.assertEqual(f1_obj.num_classes, 3) # Check save and restore config - f1_obj2 = F1MacroAndWeighted.from_config(f1_obj.get_config()) - self.assertEqual(f1_obj2.name, 'f1_macro_and_weighted_score') + f1_obj2 = F1Score.from_config(f1_obj.get_config()) + self.assertEqual(f1_obj2.name, 'f1_score') self.assertEqual(f1_obj2.dtype, tf.float32) self.assertEqual(f1_obj2.num_classes, 3) def initialize_vars(self): - f1_obj = F1MacroAndWeighted(num_classes=3, average='macro') - f1_obj1 = F1MacroAndWeighted(num_classes=3, average='weighted') + f1_obj = F1Score(num_classes=3, average='micro') + f1_obj1 = F1Score(num_classes=3, average='macro') + f1_obj2 = F1Score(num_classes=3, average='weighted') self.evaluate(tf.compat.v1.variables_initializer(f1_obj.variables)) self.evaluate(tf.compat.v1.variables_initializer(f1_obj1.variables)) - return f1_obj, f1_obj1 + self.evaluate(tf.compat.v1.variables_initializer(f1_obj2.variables)) + return f1_obj, f1_obj1, f1_obj2 - def update_obj_states(self, f1_obj, f1_obj1, actuals, preds): + def update_obj_states(self, f1_obj, f1_obj1, f1_obj2, actuals, preds): update_op1 = f1_obj.update_state(actuals, preds) update_op2 = f1_obj1.update_state(actuals, preds) + update_op3 = f1_obj2.update_state(actuals, preds) self.evaluate(update_op1) self.evaluate(update_op2) + self.evaluate(update_op3) def check_results(self, obj, value): self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5) @@ -60,12 +64,13 @@ def test_f1_perfect_score(self): actuals = tf.constant(actuals, dtype=tf.float32) preds = tf.constant(preds, dtype=tf.float32) # Initialize - f1_obj, f1_obj1 = self.initialize_vars() + f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() # Update - self.update_obj_states(f1_obj, f1_obj1, actuals, preds) + self.update_obj_states(f1_obj, f1_obj1, f1_obj2, actuals, preds) # Check results self.check_results(f1_obj, 1.0) self.check_results(f1_obj1, 1.0) + self.check_results(f1_obj2, 1.0) def test_f1_worst_score(self): actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] @@ -73,12 +78,13 @@ def test_f1_worst_score(self): actuals = tf.constant(actuals, dtype=tf.float32) preds = tf.constant(preds, dtype=tf.float32) # Initialize - f1_obj, f1_obj1 = self.initialize_vars() + f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() # Update - self.update_obj_states(f1_obj, f1_obj1, actuals, preds) + self.update_obj_states(f1_obj, f1_obj1, f1_obj2, actuals, preds) # Check results self.check_results(f1_obj, 0.0) self.check_results(f1_obj1, 0.0) + self.check_results(f1_obj2, 0.0) def test_f1_random_score(self): actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] @@ -86,12 +92,13 @@ def test_f1_random_score(self): actuals = 
tf.constant(actuals, dtype=tf.float32)
         preds = tf.constant(preds, dtype=tf.float32)
         # Initialize
-        f1_obj, f1_obj1 = self.initialize_vars()
+        f1_obj, f1_obj1, f1_obj2 = self.initialize_vars()
         # Update
-        self.update_obj_states(f1_obj, f1_obj1, actuals, preds)
+        self.update_obj_states(f1_obj, f1_obj1, f1_obj2, actuals, preds)
         # Check results
-        self.check_results(f1_obj, 0.6555555)
-        self.check_results(f1_obj1, 0.6777777)
+        self.check_results(f1_obj, 0.6666666)
+        self.check_results(f1_obj1, 0.6555555)
+        self.check_results(f1_obj2, 0.6777777)
 
 if __name__ == '__main__':
     tf.test.main()
From 66da592a6b8fcbb9808679ea81bec37893780372 Mon Sep 17 00:00:00 2001
From: saishruthi
Date: Sun, 16 Jun 2019 16:12:31 -0700
Subject: [PATCH 5/9] updates

---
 tensorflow_addons/metrics/f1_scores.py | 87 +++++++++++++++++++++-----
 tensorflow_addons/metrics/f1_test.py   | 43 +++++++++----
 2 files changed, 102 insertions(+), 28 deletions(-)

diff --git a/tensorflow_addons/metrics/f1_scores.py b/tensorflow_addons/metrics/f1_scores.py
index 59ae4cda8d..a789d70d43 100644
--- a/tensorflow_addons/metrics/f1_scores.py
+++ b/tensorflow_addons/metrics/f1_scores.py
@@ -26,7 +26,60 @@ class F1Score(Metric):
     """
     Calculates F1 micro, macro or weighted based on the
-    user's choice
+    user's choice.
+
+    F1 score is the harmonic mean of precision and
+    recall. Output range is [0, 1]. This works for both
+    multi-class and multi-label classification.
+
+    Args:
+        num_classes : Number of unique classes in the dataset.
+        average : Type of averaging to be performed on data.
+                  Acceptable values are None, micro, macro and
+                  weighted. Default value is None.
+
+    Returns:
+        F1 score: float (a 1-D array of per-class scores when
+        `average` is None).
+
+    Raises:
+        ValueError: If `average` has a value other than
+        [None, micro, macro, weighted].
+
+    `average` parameter behavior:
+
+    1. If `None` is specified as an input, scores for each
+       class are returned.
+
+    2. If `micro` is specified, metrics like true positives,
+       false positives and false negatives are computed
+       globally.
+
+    3. If `macro` is specified, metrics like true positives,
+       false positives and false negatives are computed for
+       each class and their unweighted mean is returned.
+       Imbalance in the dataset is not taken into account
+       when calculating the score.
+
+    4. If `weighted` is specified, metrics are computed for
+       each class and the mean weighted by the number of
+       true instances in each class is returned, taking data
+       imbalance into account.
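+
+    As an illustration (an editor's example with made-up numbers,
+    not part of the original patch): with per-class F1 scores of
+    [0.9, 0.5, 0.7] and true-instance counts of [10, 2, 8],
+    `macro` averaging gives (0.9 + 0.5 + 0.7) / 3 = 0.7, while
+    `weighted` averaging gives
+    (10 * 0.9 + 2 * 0.5 + 8 * 0.7) / 20 = 0.78.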
+
+    Usage:
+    ```python
+    actuals = tf.constant([[1, 1, 0], [1, 0, 0]],
+                          dtype=tf.int32)
+    preds = tf.constant([[1, 0, 0], [1, 0, 1]],
+                        dtype=tf.int32)
+    output = tfa.metrics.F1Score(num_classes=3,
+                                 average='micro')
+    output.update_state(actuals, preds)
+    print('F1 Micro score is: ',
+          output.result().numpy())  # 0.6666667
+    ```
+
     """

     def __init__(self, num_classes, average=None,
@@ -40,7 +93,7 @@ def __init__(self, num_classes, average=None,
             self.average = average
         if self.average == 'micro':
             self.axis = None
-        if self.average == 'macro' or self.average == 'weighted':
+        else:
             self.axis = 0
         if self.average == 'micro':
             self.true_positives = self.add_weight('true_positives',
@@ -95,28 +148,30 @@ def result(self):
         p_sum = tf.cast(self.true_positives + self.false_positives,
                         tf.float32)
         # calculate precision
-        precision_macro = tf.math.divide_no_nan(self.true_positives,
-                                                p_sum)
+        precision = tf.math.divide_no_nan(self.true_positives,
+                                          p_sum)
 
         r_sum = tf.cast(self.true_positives + self.false_negatives,
                         tf.float32)
         # calculate recall
-        recall_macro = tf.math.divide_no_nan(self.true_positives,
-                                             r_sum)
-
-        mul_value = 2 * precision_macro * recall_macro
-        add_value = precision_macro + recall_macro
-        f1_macro_int = tf.math.divide_no_nan(mul_value, add_value)
-        # f1 macro score
-        f1_score = tf.reduce_mean(f1_macro_int)
+        recall = tf.math.divide_no_nan(self.true_positives,
+                                       r_sum)
+
+        mul_value = 2 * precision * recall
+        add_value = precision + recall
+        f1_int = tf.math.divide_no_nan(mul_value, add_value)
+        # f1 score
+        if self.average is not None:
+            f1_score = tf.reduce_mean(f1_int)
+        else:
+            f1_score = f1_int
         # condition for weighted f1 score
         if self.average == 'weighted':
-            f1_int_weights = tf.cast(tf.math.divide_no_nan(
+            f1_int_weights = tf.math.divide_no_nan(
                 self.weights_intermediate, tf.reduce_sum(
-                    self.weights_intermediate)),
-                tf.float32)
+                    self.weights_intermediate))
             # weighted f1 score calculation
-            f1_score = tf.reduce_sum(f1_macro_int * f1_int_weights)
+            f1_score = tf.reduce_sum(f1_int * f1_int_weights)
 
         return f1_score
 
diff --git a/tensorflow_addons/metrics/f1_test.py b/tensorflow_addons/metrics/f1_test.py
index b1b4472171..617c1340aa 100644
--- a/tensorflow_addons/metrics/f1_test.py
+++ b/tensorflow_addons/metrics/f1_test.py
@@ -47,6 +47,12 @@ def initialize_vars(self):
         self.evaluate(tf.compat.v1.variables_initializer(f1_obj2.variables))
         return f1_obj, f1_obj1, f1_obj2
 
+    def initialize_vars_none(self):
+        f1_none = F1Score(num_classes=3, average=None)
+
+        self.evaluate(tf.compat.v1.variables_initializer(f1_none.variables))
+        return f1_none
+
     def update_obj_states(self, f1_obj, f1_obj1, f1_obj2, actuals, preds):
         update_op1 = f1_obj.update_state(actuals, preds)
         update_op2 = f1_obj1.update_state(actuals, preds)
@@ -55,14 +61,17 @@ def update_obj_states(self, f1_obj, f1_obj1, f1_obj2, actuals, preds):
         self.evaluate(update_op2)
         self.evaluate(update_op3)
 
+    def update_obj_states_none(self, f1_none, actuals, preds):
+        update_op1_none = f1_none.update_state(actuals, preds)
+        self.evaluate(update_op1_none)
+
     def check_results(self, obj, value):
         self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5)
 
     def test_f1_perfect_score(self):
-        actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
-        preds = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
-        actuals = tf.constant(actuals, dtype=tf.float32)
-        preds = tf.constant(preds, dtype=tf.float32)
+        actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
+                              dtype=tf.int32)
+        preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32)
         #
Initialize f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() # Update @@ -73,10 +82,9 @@ def test_f1_perfect_score(self): self.check_results(f1_obj2, 1.0) def test_f1_worst_score(self): - actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] - preds = [[0, 0, 0], [0, 1, 0], [0, 0, 1]] - actuals = tf.constant(actuals, dtype=tf.float32) - preds = tf.constant(preds, dtype=tf.float32) + actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], + dtype=tf.int32) + preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32) # Initialize f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() # Update @@ -87,10 +95,9 @@ def test_f1_worst_score(self): self.check_results(f1_obj2, 0.0) def test_f1_random_score(self): - actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] - preds = [[0, 0, 1], [1, 1, 0], [1, 1, 1]] - actuals = tf.constant(actuals, dtype=tf.float32) - preds = tf.constant(preds, dtype=tf.float32) + actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], + dtype=tf.int32) + preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32) # Initialize f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() # Update @@ -100,6 +107,18 @@ def test_f1_random_score(self): self.check_results(f1_obj1, 0.6555555) self.check_results(f1_obj2, 0.6777777) + def test_f1_none_score(self): + actuals = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], + [0, 1, 0], [0, 0, 1]], dtype=tf.int32) + preds = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], + [1, 0, 0], [0, 0, 1]], dtype=tf.int32) + # Initialize + f1_none = self.initialize_vars_none() + # Update + self.update_obj_states_none(f1_none, actuals, preds) + # Check results + self.check_results(f1_none, [0.8, 0.6666667, 1.]) + if __name__ == '__main__': tf.test.main() From bba550b8388790b893c8e6d846cebf60ef9550b3 Mon Sep 17 00:00:00 2001 From: saishruthi Date: Sun, 16 Jun 2019 16:19:32 -0700 Subject: [PATCH 6/9] updates --- .DS_Store | Bin 0 -> 6148 bytes tensorflow_addons/.DS_Store | Bin 0 -> 6148 bytes tensorflow_addons/metrics/__init__.py | 1 - tensorflow_addons/metrics/f1_scores.py | 115 +++++++++++++------------ tensorflow_addons/metrics/f1_test.py | 13 +-- 5 files changed, 69 insertions(+), 60 deletions(-) create mode 100644 .DS_Store create mode 100644 tensorflow_addons/.DS_Store mode change 100644 => 100755 tensorflow_addons/metrics/__init__.py mode change 100644 => 100755 tensorflow_addons/metrics/f1_scores.py mode change 100644 => 100755 tensorflow_addons/metrics/f1_test.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..dec5535e12e92e5d7497136f5b0e161b6be2b11b GIT binary patch literal 6148 zcmeHKQE$^Q5I(ov<`$t9X;OLg3vYdB5v<+Aq^K(oyded_1E4NZ3XPP-Rg)5JT~K}o z{3HIs{?6`fx2P&eynraWi_YKqd}qg>BRd8FF zHx3Suq-@F7!`26XrVBq0@?|y(7VoI`D#&M{vtMhSZ0sO>6Hkjt`{Y8Wc@U@5g-MR1 zDTcg%8>f*jMtYe>nd!&34N^8`bJFgtR>Nn5t~wo*U9}qYhF$gi>}gpxW#>u%{N?4< z_08(G{PNY(E%3`Ua_n#o-_ZDtnSb^dX`<6l7`^Uk-2)>tzzi@0`^tdZ&*X!BodCZ# zGr$b|F$QRVP^g5S!^)!BIS5T0$TZYp9ALOt%qTMt##>Olxq58i}`9#k~3MFTaFCbej+n0HWDztpZd4K%x>B9BjT3ijz)A!FmXV9wUS+NP>QH z7j!1F+3_D4pl>$^x6ps=?nSLckK>tu zl#X;RD%C98T5r_UW@AuOS);yHQybfBgMlNfE2}#Pm$wg(+0)?p#ZW8oFVwPVaR{$y zd|Bna*Nr-Hbd8Z=)i9fp8DIvOfoW&J?t5~6+84^}V+NRkUo$}agF+>AE#?OG)`1OO zA1PiUBte_r5`@a4YcV&75fouc5lyMEPYhwo(XVWtYcV%y%0cLvaU46durCy$XGg!% z=^$K#+%f~qz&Ha%)2-0?zyI_5f4qo$%m6bmRSbyIp10S;CE2reVsUiVD%49<63WXB lj!MwbM=|EoQM`<*1^tQ)MAu?&5Irb-5zsVn!wmc>10R Date: Sun, 16 Jun 2019 16:33:55 -0700 Subject: [PATCH 7/9] removing ds_store --- .DS_Store | Bin 6148 -> 0 bytes tensorflow_addons/.DS_Store | Bin 6148 
-> 0 bytes
 2 files changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 .DS_Store
 delete mode 100644 tensorflow_addons/.DS_Store

diff --git a/.DS_Store b/.DS_Store
deleted file mode 100644
index dec5535e12e92e5d7497136f5b0e161b6be2b11b..0000000000000000000000000000000000000000
Binary files a/.DS_Store and /dev/null differ

From 1346cabc9ce2d53eb1fc67c99bb3cb3e638b5c2c Mon Sep 17 00:00:00 2001
From: saishruthi
Date: Mon, 17 Jun 2019 10:21:09 -0700
Subject: [PATCH 8/9] minor updates

---
 tensorflow_addons/metrics/BUILD      | 13 +++++++++++++
 tensorflow_addons/metrics/f1_test.py |  2 +-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/tensorflow_addons/metrics/BUILD b/tensorflow_addons/metrics/BUILD
index 1693064862..f176a95f67 100644
--- a/tensorflow_addons/metrics/BUILD
+++ b/tensorflow_addons/metrics/BUILD
@@ -41,3 +41,16 @@ py_test(
         ":metrics",
     ],
 )
+
+py_test(
+    name = "f1_test",
+    size = "small",
+    srcs = [
+        "f1_test.py",
+    ],
+    main = "f1_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":metrics",
+    ],
+)
diff --git a/tensorflow_addons/metrics/f1_test.py b/tensorflow_addons/metrics/f1_test.py
index dbaa7da528..db0c71c4bd 100755
--- a/tensorflow_addons/metrics/f1_test.py
+++ b/tensorflow_addons/metrics/f1_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
# ============================================================================== -"""Tests F1 micro metric.""" +"""Tests F1 metrics.""" from __future__ import absolute_import from __future__ import division From 9bf94bd8a40e7b3bdd5e5d4c25c03ac7525a613d Mon Sep 17 00:00:00 2001 From: saishruthi Date: Tue, 18 Jun 2019 00:02:26 -0700 Subject: [PATCH 9/9] name updates --- tensorflow_addons/metrics/f1_test.py | 64 ++++++++++++++-------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/tensorflow_addons/metrics/f1_test.py b/tensorflow_addons/metrics/f1_test.py index db0c71c4bd..323b8432f9 100755 --- a/tensorflow_addons/metrics/f1_test.py +++ b/tensorflow_addons/metrics/f1_test.py @@ -37,14 +37,15 @@ def test_config(self): self.assertEqual(f1_obj2.num_classes, 3) def initialize_vars(self): - f1_obj = F1Score(num_classes=3, average='micro') - f1_obj1 = F1Score(num_classes=3, average='macro') - f1_obj2 = F1Score(num_classes=3, average='weighted') + f1_micro = F1Score(num_classes=3, average='micro') + f1_macro = F1Score(num_classes=3, average='macro') + f1_weighted = F1Score(num_classes=3, average='weighted') - self.evaluate(tf.compat.v1.variables_initializer(f1_obj.variables)) - self.evaluate(tf.compat.v1.variables_initializer(f1_obj1.variables)) - self.evaluate(tf.compat.v1.variables_initializer(f1_obj2.variables)) - return f1_obj, f1_obj1, f1_obj2 + self.evaluate(tf.compat.v1.variables_initializer(f1_micro.variables)) + self.evaluate(tf.compat.v1.variables_initializer(f1_macro.variables)) + self.evaluate( + tf.compat.v1.variables_initializer(f1_weighted.variables)) + return f1_micro, f1_macro, f1_weighted def initialize_vars_none(self): f1_none = F1Score(num_classes=3, average=None) @@ -52,17 +53,18 @@ def initialize_vars_none(self): self.evaluate(tf.compat.v1.variables_initializer(f1_none.variables)) return f1_none - def update_obj_states(self, f1_obj, f1_obj1, f1_obj2, actuals, preds): - update_op1 = f1_obj.update_state(actuals, preds) - update_op2 = f1_obj1.update_state(actuals, preds) - update_op3 = f1_obj2.update_state(actuals, preds) - self.evaluate(update_op1) - self.evaluate(update_op2) - self.evaluate(update_op3) + def update_obj_states(self, f1_micro, f1_macro, f1_weighted, actuals, + preds): + update_micro = f1_micro.update_state(actuals, preds) + update_macro = f1_macro.update_state(actuals, preds) + update_weighted = f1_weighted.update_state(actuals, preds) + self.evaluate(update_micro) + self.evaluate(update_macro) + self.evaluate(update_weighted) def update_obj_states_none(self, f1_none, actuals, preds): - update_op1_none = f1_none.update_state(actuals, preds) - self.evaluate(update_op1_none) + update_none = f1_none.update_state(actuals, preds) + self.evaluate(update_none) def check_results(self, obj, value): self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5) @@ -72,39 +74,39 @@ def test_f1_perfect_score(self): dtype=tf.int32) preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32) # Initialize - f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() + f1_micro, f1_macro, f1_weighted = self.initialize_vars() # Update - self.update_obj_states(f1_obj, f1_obj1, f1_obj2, actuals, preds) + self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals, preds) # Check results - self.check_results(f1_obj, 1.0) - self.check_results(f1_obj1, 1.0) - self.check_results(f1_obj2, 1.0) + self.check_results(f1_micro, 1.0) + self.check_results(f1_macro, 1.0) + self.check_results(f1_weighted, 1.0) def test_f1_worst_score(self): actuals = 
tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32) preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32) # Initialize - f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() + f1_micro, f1_macro, f1_weighted = self.initialize_vars() # Update - self.update_obj_states(f1_obj, f1_obj1, f1_obj2, actuals, preds) + self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals, preds) # Check results - self.check_results(f1_obj, 0.0) - self.check_results(f1_obj1, 0.0) - self.check_results(f1_obj2, 0.0) + self.check_results(f1_micro, 0.0) + self.check_results(f1_macro, 0.0) + self.check_results(f1_weighted, 0.0) def test_f1_random_score(self): actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32) preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32) # Initialize - f1_obj, f1_obj1, f1_obj2 = self.initialize_vars() + f1_micro, f1_macro, f1_weighted = self.initialize_vars() # Update - self.update_obj_states(f1_obj, f1_obj1, f1_obj2, actuals, preds) + self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals, preds) # Check results - self.check_results(f1_obj, 0.6666666) - self.check_results(f1_obj1, 0.6555555) - self.check_results(f1_obj2, 0.6777777) + self.check_results(f1_micro, 0.6666666) + self.check_results(f1_macro, 0.6555555) + self.check_results(f1_weighted, 0.6777777) def test_f1_none_score(self): actuals = tf.constant(
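
The expected values asserted in these tests can be reproduced outside TensorFlow. The sketch below is an editor's cross-check, not part of the patch series; it assumes NumPy and scikit-learn are available and reuses the inputs from test_f1_random_score in multilabel indicator format:

```python
# Editor's cross-check of the expected test values (assumes numpy and
# scikit-learn are installed; not part of the patches themselves).
import numpy as np
from sklearn.metrics import f1_score

# Inputs from test_f1_random_score, as multilabel indicator arrays.
actuals = np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]])
preds = np.array([[0, 0, 1], [1, 1, 0], [1, 1, 1]])

print(f1_score(actuals, preds, average='micro'))     # ~0.6666666
print(f1_score(actuals, preds, average='macro'))     # ~0.6555555
print(f1_score(actuals, preds, average='weighted'))  # ~0.6777777
print(f1_score(actuals, preds, average=None))        # [0.8, 0.5, 0.66666667]
```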