From afa8c504565686535c563394c0d6d233ebaf9f92 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Wed, 11 Sep 2019 19:56:38 +0530 Subject: [PATCH 1/8] CLN: Refactor f_scores and f_test * Add `threshold` param to f-scores * Tests now compare with sklearn * Add sklearn to requirements --- tensorflow_addons/metrics/BUILD | 19 +- tensorflow_addons/metrics/f1_test.py | 130 -------- tensorflow_addons/metrics/f_scores.py | 387 +++++++++--------------- tensorflow_addons/metrics/f_test.py | 109 +++++++ tensorflow_addons/metrics/fbeta_test.py | 136 --------- tensorflow_addons/metrics/utils.py | 16 + 6 files changed, 268 insertions(+), 529 deletions(-) delete mode 100755 tensorflow_addons/metrics/f1_test.py create mode 100644 tensorflow_addons/metrics/f_test.py delete mode 100644 tensorflow_addons/metrics/fbeta_test.py diff --git a/tensorflow_addons/metrics/BUILD b/tensorflow_addons/metrics/BUILD index 33db5e905f..00b66afa9f 100644 --- a/tensorflow_addons/metrics/BUILD +++ b/tensorflow_addons/metrics/BUILD @@ -45,25 +45,12 @@ py_test( ) py_test( - name = "f1_test", + name = "f_test", size = "small", srcs = [ - "f1_test.py", + "f_test.py", ], - main = "f1_test.py", - srcs_version = "PY2AND3", - deps = [ - ":metrics", - ], -) - -py_test( - name = "fbeta_test", - size = "small", - srcs = [ - "fbeta_test.py", - ], - main = "fbeta_test.py", + main = "f_test.py", srcs_version = "PY2AND3", deps = [ ":metrics", diff --git a/tensorflow_addons/metrics/f1_test.py b/tensorflow_addons/metrics/f1_test.py deleted file mode 100755 index e11165bb2c..0000000000 --- a/tensorflow_addons/metrics/f1_test.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests F1 metrics.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow_addons.metrics import F1Score -from tensorflow_addons.utils import test_utils -from tensorflow.keras import layers -import numpy as np - - -@test_utils.run_all_in_graph_and_eager_modes -class F1ScoreTest(tf.test.TestCase): - def test_config(self): - f1_obj = F1Score(num_classes=3, average=None) - self.assertEqual(f1_obj.name, 'f1_score') - self.assertEqual(f1_obj.dtype, tf.float32) - self.assertEqual(f1_obj.num_classes, 3) - self.assertEqual(f1_obj.average, None) - # Check save and restore config - f1_obj2 = F1Score.from_config(f1_obj.get_config()) - self.assertEqual(f1_obj2.name, 'f1_score') - self.assertEqual(f1_obj2.dtype, tf.float32) - self.assertEqual(f1_obj2.num_classes, 3) - self.assertEqual(f1_obj2.average, None) - - def initialize_vars(self, average): - # initialize variables - f1_obj = F1Score(num_classes=3, average=average) - - self.evaluate(tf.compat.v1.variables_initializer(f1_obj.variables)) - - return f1_obj - - def update_obj_states(self, f1_obj, actuals, preds): - # update state variable values - update_op = f1_obj.update_state(actuals, preds) - self.evaluate(update_op) - - def check_results(self, f1_obj, value): - # check result - self.assertAllClose(value, self.evaluate(f1_obj.result()), atol=1e-5) - - def _test_f1(self, avg, act, pred, res): - f1_init = self.initialize_vars(avg) - self.update_obj_states(f1_init, act, pred) - self.check_results(f1_init, res) - - def _test_f1_score(self, actuals, preds, res): - # test for three average values with beta value as 1.0 - for avg in ['micro', 'macro', 'weighted']: - self._test_f1(avg, actuals, preds, res) - - # test for perfect f1 score - def test_f1_perfect_score(self): - actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], - dtype=tf.int32) - preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32) - self._test_f1_score(actuals, preds, 1.0) - - # test for worst f1 score - def test_f1_worst_score(self): - actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], - dtype=tf.int32) - preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32) - self._test_f1_score(actuals, preds, 0.0) - - # test for random f1 score - def test_f1_random_score(self): - actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], - dtype=tf.int32) - preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32) - # Use absl parameterized test here if possible - test_params = [['micro', 0.6666667], ['macro', 0.65555555], - ['weighted', 0.67777777]] - - for avg, res in test_params: - self._test_f1(avg, actuals, preds, res) - - # test for random f1 score with average as None - def test_f1_random_score_none(self): - actuals = tf.constant( - [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]], - dtype=tf.int32) - preds = tf.constant( - [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]], - dtype=tf.int32) - - # Use absl parameterized test here if possible - test_params = [[None, [0.8, 0.6666667, 1.]]] - - for avg, res in test_params: - self._test_f1(avg, actuals, preds, res) - - # Keras model check - def test_keras_model(self): - model = tf.keras.Sequential() - model.add(layers.Dense(64, activation='relu')) - model.add(layers.Dense(64, activation='relu')) - model.add(layers.Dense(1, activation='softmax')) - fb = F1Score(1, 'macro') 
- model.compile( - optimizer='rmsprop', - loss='categorical_crossentropy', - metrics=['acc', fb]) - # data preparation - data = np.random.random((10, 3)) - labels = np.random.random((10, 1)) - labels = np.where(labels > 0.5, 1, 0) - model.fit(data, labels, epochs=1, batch_size=32, verbose=0) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow_addons/metrics/f_scores.py b/tensorflow_addons/metrics/f_scores.py index b7328538d4..082143064e 100755 --- a/tensorflow_addons/metrics/f_scores.py +++ b/tensorflow_addons/metrics/f_scores.py @@ -12,220 +12,155 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implements F1 scores.""" +"""Implements F scores.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf -from tensorflow.keras.metrics import Metric import numpy as np -class FBetaScore(Metric): - """Computes F-Beta Score. +class FBetaScore(tf.keras.metrics.Metric): + """Computes F-Beta score. - This is the weighted harmonic mean of precision and recall. - Output range is [0, 1]. + It is the weighted harmonic mean of precision + and recall. Output range is [0, 1]. Works for + both multi-class and multi-label classification. - F-Beta = (1 + beta^2) * ((precision * recall) / - ((beta^2 * precision) + recall)) - - `beta` parameter determines the weight given to the - precision and recall. - - `beta < 1` gives more weight to the precision. - `beta > 1` gives more weight to the recall. - `beta == 1` gives equal weight to precision and recall. + F-Beta = (1 + beta^2) * (precision * recall) / ((beta^2 * precision) + recall) Args: - num_classes : Number of unique classes in the dataset. - average : Type of averaging to be performed on data. - Acceptable values are None, micro, macro and - weighted. - beta : float. Determines the weight of precision and recall - in harmonic mean. Acceptable values are either a number - of float data type greater than 0.0 or a scale tensor - of dtype tf.float32. + num_classes: Number of unique classes in the dataset. + average: Type of averaging to be performed on data. + Acceptable values are `None`, `micro`, `macro` and + `weighted`. Default value is None. + beta: Determines the weight of precision and recall + in harmonic mean. Determines the weight given to the + precision and recall. Default value is 1. + threshold: Elements of `y_pred` greater than threshold are + converted to be 1, and the rest 0. If threshold is + None, the argmax is converted to 1, and the rest 0. Returns: - F Beta Score: float + F-Beta Score: float Raises: - ValueError: If the `average` has values other than - [None, micro, macro, weighted]. + ValueError: If the `average` has values other than + [None, micro, macro, weighted]. - ValueError: If the `beta` value is less than or equal - to 0. + ValueError: If the `beta` value is less than or equal + to 0. `average` parameter behavior: + None: Scores for each class are returned + + micro: True positivies, false positives and + false negatives are computed globally. + + macro: True positivies, false positives and + false negatives are computed for each class + and their unweighted mean is returned. - 1. If `None` is specified as an input, scores for each - class are returned. - - 2. If `micro` is specified, metrics like true positivies, - false positives and false negatives are computed - globally. - - 3. 
If `macro` is specified, metrics like true positivies, - false positives and false negatives are computed for - each class and their unweighted mean is returned. - Imbalance in dataset is not taken into account for - calculating the score - - 4. If `weighted` is specified, metrics are computed for - each class and returns the mean weighted by the - number of true instances in each class taking data - imbalance into account. - - Usage: - ```python - actuals = tf.constant([[1, 1, 0],[1, 0, 0]], - dtype=tf.int32) - preds = tf.constant([[1, 0, 0],[1, 0, 1]], - dtype=tf.int32) - # F-Beta Micro - fb_score = tfa.metrics.FBetaScore(num_classes=3, - beta=0.4, average='micro') - fb_score.update_state(actuals, preds) - print('F1-Beta Score is: ', - fb_score.result().numpy()) # 0.6666666 - # F-Beta Macro - fb_score = tfa.metrics.FBetaScore(num_classes=3, - beta=0.4, average='macro') - fb_score.update_state(actuals, preds) - print('F1-Beta Score is: ', - fb_score.result().numpy()) # 0.33333334 - # F-Beta Weighted - fb_score = tfa.metrics.FBetaScore(num_classes=3, - beta=0.4, average='weighted') - fb_score.update_state(actuals, preds) - print('F1-Beta Score is: ', - fb_score.result().numpy()) # 0.6666667 - # F-Beta score for each class (average=None). - fb_score = tfa.metrics.FBetaScore(num_classes=3, - beta=0.4, average=None) - fb_score.update_state(actuals, preds) - print('F1-Beta Score is: ', - fb_score.result().numpy()) # [1. 0. 0.] - ``` + weighted: Metrics are computed for each class + and returns the mean weighted by the + number of true instances in each class. """ def __init__(self, num_classes, average=None, beta=1.0, + threshold=None, name='fbeta_score', dtype=tf.float32): super(FBetaScore, self).__init__(name=name) - self.num_classes = num_classes - # type check - if not isinstance(beta, float) and beta.dtype != tf.float32: - raise TypeError("The value of beta should be float") - # value check - if beta <= 0.0: - raise ValueError("beta value should be greater than zero") - else: - self.beta = beta + if average not in (None, 'micro', 'macro', 'weighted'): raise ValueError("Unknown average type. 
Acceptable values " "are: [None, micro, macro, weighted]") - else: - self.average = average - if self.average == 'micro': - self.axis = None - else: - self.axis = 0 - if self.average == 'micro': - self.true_positives = self.add_weight( - 'true_positives', - shape=[], - initializer='zeros', - dtype=self.dtype) - self.false_positives = self.add_weight( - 'false_positives', - shape=[], - initializer='zeros', - dtype=self.dtype) - self.false_negatives = self.add_weight( - 'false_negatives', - shape=[], - initializer='zeros', - dtype=self.dtype) - else: - self.true_positives = self.add_weight( - 'true_positives', - shape=[self.num_classes], - initializer='zeros', - dtype=self.dtype) - self.false_positives = self.add_weight( - 'false_positives', - shape=[self.num_classes], - initializer='zeros', - dtype=self.dtype) - self.false_negatives = self.add_weight( - 'false_negatives', - shape=[self.num_classes], - initializer='zeros', - dtype=self.dtype) - self.weights_intermediate = self.add_weight( - 'weights', - shape=[self.num_classes], + + if not isinstance(beta, float): + raise TypeError("The value of beta should be a python float") + + if beta <= 0.0: + raise ValueError("beta value should be greater than zero") + + if threshold is not None: + if not isinstance(threshold, float): + raise TypeError( + "The value of threshold should be a python float") + if threshold > 1.0 or threshold <= 0.0: + raise ValueError("threshold should be between 0 and 1") + + self.num_classes = num_classes + self.average = average + self.beta = beta + self.threshold = threshold + self.axis = None + self.init_shape = [] + + if self.average != 'micro': + self.axis = 0 + self.init_shape = [self.num_classes] + + def _zero_wt_init(name): + return self.add_weight( + name, + shape=self.init_shape, initializer='zeros', dtype=self.dtype) + self.true_positives = _zero_wt_init('true_positives') + self.false_positives = _zero_wt_init('false_positives') + self.false_negatives = _zero_wt_init('false_negatives') + self.weights_intermediate = _zero_wt_init('weights_intermediate') + # TODO: Add sample_weight support, currently it is # ignored during calculations. 
def update_state(self, y_true, y_pred, sample_weight=None): + y_pred = tf.cast(y_pred, tf.float32) + + if self.threshold is None: + threshold = tf.reduce_max(y_pred, axis=-1, keepdims=True) + # make sure [0, 0, 0] doesn't become [1, 1, 1] + # Use (x - 0 > eps) to check for fp zero equality + y_pred = tf.logical_and(y_pred >= threshold, y_pred - 0 > 1e-12) + else: + y_pred = y_pred > self.threshold + y_true = tf.cast(y_true, tf.int32) y_pred = tf.cast(y_pred, tf.int32) - # true positive - self.true_positives.assign_add( - tf.cast( - tf.math.count_nonzero(y_pred * y_true, axis=self.axis), - self.dtype)) - # false positive - self.false_positives.assign_add( - tf.cast( - tf.math.count_nonzero(y_pred * (y_true - 1), axis=self.axis), - self.dtype)) - # false negative - self.false_negatives.assign_add( - tf.cast( - tf.math.count_nonzero((y_pred - 1) * y_true, axis=self.axis), - self.dtype)) - if self.average == 'weighted': - # variable to hold intermediate weights - self.weights_intermediate.assign_add( - tf.cast(tf.reduce_sum(y_true, axis=self.axis), self.dtype)) + def _count_non_zero(val): + non_zeros = tf.math.count_nonzero(val, axis=self.axis) + return tf.cast(non_zeros, self.dtype) + + self.true_positives.assign_add(_count_non_zero(y_pred * y_true)) + self.false_positives.assign_add(_count_non_zero(y_pred * (y_true - 1))) + self.false_negatives.assign_add(_count_non_zero((y_pred - 1) * y_true)) + self.weights_intermediate.assign_add(_count_non_zero(y_true)) def result(self): - p_sum = tf.cast(self.true_positives + self.false_positives, self.dtype) - # calculate precision - precision = tf.math.divide_no_nan(self.true_positives, p_sum) - - r_sum = tf.cast(self.true_positives + self.false_negatives, self.dtype) - # calculate recall - recall = tf.math.divide_no_nan(self.true_positives, r_sum) - # intermediate calculations + precision = tf.math.divide_no_nan( + self.true_positives, self.true_positives + self.false_positives) + recall = tf.math.divide_no_nan( + self.true_positives, self.true_positives + self.false_negatives) + mul_value = precision * recall add_value = (tf.math.square(self.beta) * precision) + recall - f1_int = (1 + tf.math.square(self.beta)) * (tf.math.divide_no_nan( - mul_value, add_value)) - # f1 score - if self.average is not None: - f1_score = tf.reduce_mean(f1_int) - else: - f1_score = f1_int - # condition for weighted f1 score + mean = (tf.math.divide_no_nan(mul_value, add_value)) + f1_score = mean * (1 + tf.math.square(self.beta)) + if self.average == 'weighted': - f1_int_weights = tf.math.divide_no_nan( + weights = tf.math.divide_no_nan( self.weights_intermediate, tf.reduce_sum(self.weights_intermediate)) - # weighted f1 score calculation - f1_score = tf.reduce_sum(f1_int * f1_int_weights) + f1_score = tf.reduce_sum(f1_score * weights) + + elif self.average is not None: # [micro, macro] + f1_score = tf.reduce_mean(f1_score) return f1_score @@ -237,111 +172,69 @@ def get_config(self): "average": self.average, "beta": self.beta, } + + if self.threshold is not None: + config["threshold"] = self.threshold + base_config = super(FBetaScore, self).get_config() return dict(list(base_config.items()) + list(config.items())) def reset_states(self): - # reset state of the variables to zero - if self.average == 'micro': - self.true_positives.assign(0) - self.false_positives.assign(0) - self.false_negatives.assign(0) - else: - self.true_positives.assign(np.zeros(self.num_classes), np.float32) - self.false_positives.assign(np.zeros(self.num_classes), np.float32) - 
self.false_negatives.assign(np.zeros(self.num_classes), np.float32) - self.weights_intermediate.assign( - np.zeros(self.num_classes), np.float32) + self.true_positives.assign(tf.zeros(self.init_shape, self.dtype)) + self.false_positives.assign(tf.zeros(self.init_shape, self.dtype)) + self.false_negatives.assign(tf.zeros(self.init_shape, self.dtype)) + self.weights_intermediate.assign(tf.zeros(self.init_shape, self.dtype)) class F1Score(FBetaScore): - """Computes F1 micro, macro or weighted based on the user's choice. + """Computes F-1 Score. - F1 score is the weighted average of precision and - recall. Output range is [0, 1]. This works for both - multi-class and multi-label classification. + It is the harmonic mean of precision and recall. + Output range is [0, 1]. Works for both multi-class + and multi-label classification. - F-1 = (2) * ((precision * recall) / (precision + recall)) + F-1 = 2 * (precision * recall) / (precision + recall) Args: - num_classes : Number of unique classes in the dataset. - average : Type of averaging to be performed on data. - Acceptable values are `None`, `micro`, `macro` and - `weighted`. - Default value is None. - beta : float - Determines the weight of precision and recall in harmonic - mean. Its value is 1.0 for F1 score. + num_classes: Number of unique classes in the dataset. + average: Type of averaging to be performed on data. + Acceptable values are `None`, `micro`, `macro` + and `weighted`. Default value is None. + threshold: Elements of `y_pred` above threshold are + considered to be 1, and the rest 0. If threshold is + None, the argmax is converted to 1, and the rest 0. Returns: - F1 Score: float + F-1 Score: float Raises: - ValueError: If the `average` has values other than - [None, micro, macro, weighted]. - - ValueError: If the `beta` value is less than or equal - to 0. + ValueError: If the `average` has values other than + [None, micro, macro, weighted]. `average` parameter behavior: + None: Scores for each class are returned - 1. If `None` is specified as an input, scores for each - class are returned. - - 2. If `micro` is specified, metrics like true positivies, - false positives and false negatives are computed - globally. - - 3. If `macro` is specified, metrics like true positivies, - false positives and false negatives are computed for - each class and their unweighted mean is returned. - Imbalance in dataset is not taken into account for - calculating the score - - 4. If `weighted` is specified, metrics are computed for - each class and returns the mean weighted by the - number of true instances in each class taking data - imbalance into account. - - Usage: - ```python - actuals = tf.constant([[1, 1, 0],[1, 0, 0]], - dtype=tf.int32) - preds = tf.constant([[1, 0, 0],[1, 0, 1]], - dtype=tf.int32) - # F1 Micro - output = tfa.metrics.F1Score(num_classes=3, - average='micro') - output.update_state(actuals, preds) - print('F1 Micro score is: ', - output.result().numpy()) # 0.6666667 - # F1 Macro - output = tfa.metrics.F1Score(num_classes=3, - average='macro') - output.update_state(actuals, preds) - print('F1 Macro score is: ', - output.result().numpy()) # 0.33333334 - # F1 weighted - output = tfa.metrics.F1Score(num_classes=3, - average='weighted') - output.update_state(actuals, preds) - print('F1 Weighted score is: ', - output.result().numpy()) # 0.6666667 - # F1 score for each class (average=None). - output = tfa.metrics.F1Score(num_classes=3) - output.update_state(actuals, preds) - print('F1 score is: ', - output.result().numpy()) # [1. 0. 
0.] - ``` + micro: True positivies, false positives and + false negatives are computed globally. + + macro: True positivies, false positives and + false negatives are computed for each class + and their unweighted mean is returned. + + weighted: Metrics are computed for each class + and returns the mean weighted by the + number of true instances in each class. """ - def __init__(self, num_classes, average, name='f1_score', + def __init__(self, + num_classes, + average, + threshold=None, + name='f1_score', dtype=tf.float32): super(F1Score, self).__init__( - num_classes, average, 1.0, name=name, dtype=dtype) + num_classes, average, 1.0, threshold, name=name, dtype=dtype) - # TODO: Add sample_weight support, currently it is - # ignored during calculations. def get_config(self): base_config = super(F1Score, self).get_config() del base_config["beta"] diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py new file mode 100644 index 0000000000..c046bd793b --- /dev/null +++ b/tensorflow_addons/metrics/f_test.py @@ -0,0 +1,109 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests F beta metrics.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow_addons.metrics import FBetaScore, utils +from tensorflow_addons.utils import test_utils + +import numpy as np +from sklearn.metrics import fbeta_score + + +@test_utils.run_all_in_graph_and_eager_modes +class FBetaScoreTest(tf.test.TestCase): + def test_config(self): + fbeta_obj = FBetaScore( + num_classes=3, beta=0.5, threshold=0.3, average=None) + self.assertEqual(fbeta_obj.beta, 0.5) + self.assertEqual(fbeta_obj.average, None) + self.assertEqual(fbeta_obj.threshold, 0.3) + self.assertEqual(fbeta_obj.num_classes, 3) + self.assertEqual(fbeta_obj.dtype, tf.float32) + + # Check save and restore config + fbeta_obj2 = FBetaScore.from_config(fbeta_obj.get_config()) + self.assertEqual(fbeta_obj2.beta, 0.5) + self.assertEqual(fbeta_obj2.average, None) + self.assertEqual(fbeta_obj2.threshold, 0.3) + self.assertEqual(fbeta_obj2.num_classes, 3) + self.assertEqual(fbeta_obj2.dtype, tf.float32) + + def _test_tf(self, avg, beta, act, pred, threshold): + act = tf.constant(act, tf.float32) + pred = tf.constant(pred, tf.float32) + + fbeta = FBetaScore(3, avg, beta, threshold) + update_op = fbeta.update_state(act, pred) + + self.evaluate(tf.compat.v1.variables_initializer(fbeta.variables)) + self.evaluate(update_op) + return self.evaluate(fbeta.result()) + + def _test_sk(self, avg, beta, act, pred, threshold): + act = np.array(act) + pred = np.array(pred) + if threshold is None: + threshold = np.max(pred, axis=-1, keepdims=True) + pred = np.logical_and(pred >= threshold, + pred - 0 > 1e-12).astype('int') + else: + pred = (pred >= threshold).astype('int') + + res = fbeta_score(act, pred, 
beta, average=avg) + return res + + def _test_fbeta_score(self, actuals, preds, threshold=None): + for avg in [None, 'micro', 'macro', 'weighted']: + for beta_val in [0.5, 1.0, 2.0]: + tf_score = self._test_tf(avg, beta_val, actuals, preds, + threshold) + sk_score = self._test_sk(avg, beta_val, actuals, preds, + threshold) + self.assertAllClose(tf_score, sk_score, atol=1e-5) + + def test_fbeta_perfect_score(self): + preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] + actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] + self._test_fbeta_score(actuals, preds, 0.66) + + def test_fbeta_worst_score(self): + preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] + actuals = [[0, 0, 0], [0, 1, 0], [0, 0, 1]] + self._test_fbeta_score(actuals, preds, 0.66) + + def test_fbeta_random_score(self): + preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] + actuals = [[0, 0, 1], [1, 1, 0], [1, 1, 1]] + self._test_fbeta_score(actuals, preds, 0.66) + + def test_fbeta_random_score_none(self): + preds = [[0.9, 0.1, 0], [0.2, 0.6, 0.2], [0, 0, 1], [0.4, 0.3, 0.3], + [0, 0.9, 0.1], [0, 0, 1]] + actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], + [0, 0, 1]] + self._test_fbeta_score(actuals, preds, None) + + def test_keras_model(self): + fbeta = FBetaScore(5, 'micro', 1.0) + utils.test_keras_model(fbeta, 5) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_addons/metrics/fbeta_test.py b/tensorflow_addons/metrics/fbeta_test.py deleted file mode 100644 index 69a5b730d1..0000000000 --- a/tensorflow_addons/metrics/fbeta_test.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests F beta metrics.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf -from tensorflow_addons.metrics import FBetaScore -from tensorflow_addons.utils import test_utils -from tensorflow.keras import layers -import numpy as np - - -@test_utils.run_all_in_graph_and_eager_modes -class FBetaScoreTest(tf.test.TestCase): - def test_config(self): - fbeta_obj = FBetaScore(num_classes=3, beta=0.5, average=None) - self.assertEqual(fbeta_obj.beta, 0.5) - self.assertEqual(fbeta_obj.average, None) - self.assertEqual(fbeta_obj.num_classes, 3) - self.assertEqual(fbeta_obj.dtype, tf.float32) - # Check save and restore config - fbeta_obj2 = FBetaScore.from_config(fbeta_obj.get_config()) - self.assertEqual(fbeta_obj2.beta, 0.5) - self.assertEqual(fbeta_obj2.average, None) - self.assertEqual(fbeta_obj2.num_classes, 3) - self.assertEqual(fbeta_obj2.dtype, tf.float32) - - def initialize_vars(self, beta_val, average): - # initialize variables - fbeta_obj = FBetaScore(num_classes=3, beta=beta_val, average=average) - - self.evaluate(tf.compat.v1.variables_initializer(fbeta_obj.variables)) - - return fbeta_obj - - def update_obj_states(self, fbeta_obj, actuals, preds): - # update state variables values - update_op = fbeta_obj.update_state(actuals, preds) - self.evaluate(update_op) - - def check_results(self, fbeta_obj, value): - # check results - self.assertAllClose( - value, self.evaluate(fbeta_obj.result()), atol=1e-5) - - def _test_fbeta(self, avg, beta, act, pred, res): - fbeta = self.initialize_vars(beta, avg) - self.update_obj_states(fbeta, act, pred) - self.check_results(fbeta, res) - - def _test_fbeta_score(self, actuals, preds, res): - # This function tests for three average values and - # two beta values - for avg in ['micro', 'macro', 'weighted']: - for beta_val in [0.5, 2.0]: - self._test_fbeta(avg, beta_val, actuals, preds, res) - - # test for the perfect score - def test_fbeta_perfect_score(self): - actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], - dtype=tf.int32) - preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32) - self._test_fbeta_score(actuals, preds, 1.0) - - # test for the worst score - def test_fbeta_worst_score(self): - actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], - dtype=tf.int32) - preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32) - self._test_fbeta_score(actuals, preds, 0.0) - - # test for the random score - def test_fbeta_random_score(self): - actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], - dtype=tf.int32) - preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32) - - # test parameters - test_params = [['micro', 0.5, 0.666667], ['macro', 0.5, 0.654882], - ['weighted', 0.5, 0.71380], ['micro', 2.0, 0.666667], - ['macro', 2.0, 0.68253], ['weighted', 2.0, 0.66269]] - - for avg, beta, res in test_params: - self._test_fbeta(avg, beta, actuals, preds, res) - - # Test for the random score with average value as None - def test_fbeta_random_score_none(self): - actuals = tf.constant( - [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]], - dtype=tf.int32) - preds = tf.constant( - [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]], - dtype=tf.int32) - - # test parameters - test_params = [[0.5, [0.71428573, 0.8333334, 1.]], - [2.0, [0.90909094, 0.5555556, 1.]]] - - for beta, res in test_params: - 
self._test_fbeta(None, beta, actuals, preds, res) - - # Keras model check - def test_keras_model(self): - model = tf.keras.Sequential() - model.add(layers.Dense(64, activation='relu')) - model.add(layers.Dense(64, activation='relu')) - model.add(layers.Dense(1, activation='softmax')) - fb = FBetaScore(1, 'macro') - model.compile( - optimizer='rmsprop', - loss='categorical_crossentropy', - metrics=['acc', fb]) - # data preparation - data = np.random.random((10, 3)) - labels = np.random.random((10, 1)) - labels = np.where(labels > 0.5, 1, 0) - model.fit(data, labels, epochs=1, batch_size=32, verbose=0) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow_addons/metrics/utils.py b/tensorflow_addons/metrics/utils.py index 133fd88e37..65bfa9a74f 100644 --- a/tensorflow_addons/metrics/utils.py +++ b/tensorflow_addons/metrics/utils.py @@ -18,6 +18,7 @@ from __future__ import print_function import six +import numpy as np import tensorflow as tf @@ -65,3 +66,18 @@ def get_config(self): config[k] = v base_config = super(MeanMetricWrapper, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + +def test_keras_model(metric, num_output): + # Test API comptibility with tf.keras Model + model = tf.keras.Sequential() + model.add(tf.keras.layers.Dense(64, activation='relu')) + model.add(tf.keras.layers.Dense(num_output, activation='softmax')) + model.compile( + optimizer='adam', + loss='categorical_crossentropy', + metrics=['acc', metric]) + + data = np.random.random((10, 3)) + labels = np.random.random((10, num_output)) + model.fit(data, labels, epochs=1, batch_size=5, verbose=0) From 00d7867f05429b36d77bd0b0ef96cf7b35bcdbe0 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Wed, 11 Sep 2019 20:17:13 +0530 Subject: [PATCH 2/8] Format files --- tensorflow_addons/metrics/f_scores.py | 3 +-- tensorflow_addons/metrics/f_test.py | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/metrics/f_scores.py b/tensorflow_addons/metrics/f_scores.py index 082143064e..b297a021e5 100755 --- a/tensorflow_addons/metrics/f_scores.py +++ b/tensorflow_addons/metrics/f_scores.py @@ -19,7 +19,6 @@ from __future__ import print_function import tensorflow as tf -import numpy as np class FBetaScore(tf.keras.metrics.Metric): @@ -29,7 +28,7 @@ class FBetaScore(tf.keras.metrics.Metric): and recall. Output range is [0, 1]. Works for both multi-class and multi-label classification. - F-Beta = (1 + beta^2) * (precision * recall) / ((beta^2 * precision) + recall) + F-Beta = (1 + beta^2) * (prec * recall) / ((beta^2 * prec) + recall) Args: num_classes: Number of unique classes in the dataset. 
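
For illustration, a minimal usage sketch of the `threshold` parameter introduced in this series, assuming the patched `tfa.metrics.FBetaScore` API is importable and running eagerly under TF 2.x; the inputs and expected value mirror the perfect-score test case added in `f_test.py`:

```python
import tensorflow as tf
import tensorflow_addons as tfa

y_true = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32)
y_pred = tf.constant([[0.7, 0.7, 0.7], [1.0, 0.0, 0.0], [0.9, 0.8, 0.0]],
                     dtype=tf.float32)

# Probabilities above 0.66 are counted as positive predictions; with
# threshold=None the per-row argmax would be used instead.
fbeta = tfa.metrics.FBetaScore(num_classes=3, average='micro', beta=2.0,
                               threshold=0.66)
fbeta.update_state(y_true, y_pred)
print(fbeta.result().numpy())  # 1.0 -- thresholded predictions match the labels
```
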
diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py index c046bd793b..e1c9046ef8 100644 --- a/tensorflow_addons/metrics/f_test.py +++ b/tensorflow_addons/metrics/f_test.py @@ -50,10 +50,8 @@ def _test_tf(self, avg, beta, act, pred, threshold): pred = tf.constant(pred, tf.float32) fbeta = FBetaScore(3, avg, beta, threshold) - update_op = fbeta.update_state(act, pred) - self.evaluate(tf.compat.v1.variables_initializer(fbeta.variables)) - self.evaluate(update_op) + self.evaluate(fbeta.update_state(act, pred)) return self.evaluate(fbeta.result()) def _test_sk(self, avg, beta, act, pred, threshold): From cb2b60c7191d45c8fab9e9929457256ef6b855c8 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Fri, 13 Sep 2019 17:27:00 +0530 Subject: [PATCH 3/8] Add F1 score test * Register FBetaScore and F1Score as Keras custom objects * Update readme to separate both metrics --- tensorflow_addons/metrics/README.md | 3 ++- tensorflow_addons/metrics/f_scores.py | 5 ++++- tensorflow_addons/metrics/f_test.py | 23 ++++++++++++++++++++++- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/metrics/README.md b/tensorflow_addons/metrics/README.md index db1533ec75..999a413ed3 100644 --- a/tensorflow_addons/metrics/README.md +++ b/tensorflow_addons/metrics/README.md @@ -12,7 +12,8 @@ | Submodule | Metric | Reference | |:----------------------- |:-------------------|:---------------| | cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)| -| f1_scores| F1 micro, macro and weighted| [F1 Score](https://en.wikipedia.org/wiki/F1_score)| +| f_scores| F1Score | [F1 Score](https://en.wikipedia.org/wiki/F1_score)| +| f_scores| FBetaScore | | | r_square| RSquare|[R-Sqaure](https://en.wikipedia.org/wiki/Coefficient_of_determination)| | multilabel_confusion_matrix | Multilabel Confusion Matrix | [mcm](https://en.wikipedia.org/wiki/Confusion_matrix)| diff --git a/tensorflow_addons/metrics/f_scores.py b/tensorflow_addons/metrics/f_scores.py index b297a021e5..9be60b9742 100755 --- a/tensorflow_addons/metrics/f_scores.py +++ b/tensorflow_addons/metrics/f_scores.py @@ -19,8 +19,10 @@ from __future__ import print_function import tensorflow as tf +from tensorflow_addons.utils import keras_utils +@keras_utils.register_keras_custom_object class FBetaScore(tf.keras.metrics.Metric): """Computes F-Beta score. @@ -185,6 +187,7 @@ def reset_states(self): self.weights_intermediate.assign(tf.zeros(self.init_shape, self.dtype)) +@keras_utils.register_keras_custom_object class F1Score(FBetaScore): """Computes F-1 Score. 
@@ -227,7 +230,7 @@ class F1Score(FBetaScore): def __init__(self, num_classes, - average, + average=None, threshold=None, name='f1_score', dtype=tf.float32): diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py index e1c9046ef8..ab11144343 100644 --- a/tensorflow_addons/metrics/f_test.py +++ b/tensorflow_addons/metrics/f_test.py @@ -19,7 +19,7 @@ from __future__ import print_function import tensorflow as tf -from tensorflow_addons.metrics import FBetaScore, utils +from tensorflow_addons.metrics import FBetaScore, F1Score, utils from tensorflow_addons.utils import test_utils import numpy as np @@ -103,5 +103,26 @@ def test_keras_model(self): utils.test_keras_model(fbeta, 5) +@test_utils.run_all_in_graph_and_eager_modes +class F1ScoreTest(tf.test.TestCase): + def test_eq(self): + f1 = F1Score(4) + fbeta = FBetaScore(4, beta=1.0) + self.evaluate(tf.compat.v1.variables_initializer(f1.variables)) + self.evaluate(tf.compat.v1.variables_initializer(fbeta.variables)) + + actuals = np.random.randint(2, size=(10, 4)) + preds = np.random.uniform(size=(10, 4)) + + self.evaluate(fbeta.update_state(actuals, preds)) + self.evaluate(f1.update_state(actuals, preds)) + self.assertAllClose( + self.evaluate(fbeta.result()), self.evaluate(f1.result())) + + def test_keras_model(self): + f1 = F1Score(5) + utils.test_keras_model(f1, 5) + + if __name__ == '__main__': tf.test.main() From 23250f8a26003a5e41ceec5fcbaa256b3a16e6bc Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Mon, 30 Sep 2019 19:18:46 +0530 Subject: [PATCH 4/8] Add test for F1-score get_config --- tensorflow_addons/metrics/f_scores.py | 7 +++---- tensorflow_addons/metrics/f_test.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tensorflow_addons/metrics/f_scores.py b/tensorflow_addons/metrics/f_scores.py index 9be60b9742..8b78b41a85 100755 --- a/tensorflow_addons/metrics/f_scores.py +++ b/tensorflow_addons/metrics/f_scores.py @@ -121,13 +121,12 @@ def _zero_wt_init(name): # TODO: Add sample_weight support, currently it is # ignored during calculations. 
def update_state(self, y_true, y_pred, sample_weight=None): - y_pred = tf.cast(y_pred, tf.float32) - if self.threshold is None: threshold = tf.reduce_max(y_pred, axis=-1, keepdims=True) # make sure [0, 0, 0] doesn't become [1, 1, 1] - # Use (x - 0 > eps) to check for fp zero equality - y_pred = tf.logical_and(y_pred >= threshold, y_pred - 0 > 1e-12) + # Use abs(x) > eps, instead of x != 0 to check for zero + y_pred = tf.logical_and(y_pred >= threshold, + tf.abs(y_pred) > 1e-12) else: y_pred = y_pred > self.threshold diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py index ab11144343..8fe2a1cf76 100644 --- a/tensorflow_addons/metrics/f_test.py +++ b/tensorflow_addons/metrics/f_test.py @@ -67,19 +67,17 @@ def _test_sk(self, avg, beta, act, pred, threshold): res = fbeta_score(act, pred, beta, average=avg) return res - def _test_fbeta_score(self, actuals, preds, threshold=None): + def _test_fbeta_score(self, actuals, preds, result, threshold=None): for avg in [None, 'micro', 'macro', 'weighted']: for beta_val in [0.5, 1.0, 2.0]: tf_score = self._test_tf(avg, beta_val, actuals, preds, threshold) - sk_score = self._test_sk(avg, beta_val, actuals, preds, - threshold) - self.assertAllClose(tf_score, sk_score, atol=1e-5) + self.assertAllClose(tf_score, result, atol=1e-5) def test_fbeta_perfect_score(self): preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] - self._test_fbeta_score(actuals, preds, 0.66) + self._test_fbeta_score(actuals, preds, 0.0, 0.66) def test_fbeta_worst_score(self): preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] @@ -106,13 +104,15 @@ def test_keras_model(self): @test_utils.run_all_in_graph_and_eager_modes class F1ScoreTest(tf.test.TestCase): def test_eq(self): - f1 = F1Score(4) - fbeta = FBetaScore(4, beta=1.0) + f1 = F1Score(3) + fbeta = FBetaScore(3, beta=1.0) self.evaluate(tf.compat.v1.variables_initializer(f1.variables)) self.evaluate(tf.compat.v1.variables_initializer(fbeta.variables)) - actuals = np.random.randint(2, size=(10, 4)) - preds = np.random.uniform(size=(10, 4)) + preds = [[0.9, 0.1, 0], [0.2, 0.6, 0.2], [0, 0, 1], [0.4, 0.3, 0.3], + [0, 0.9, 0.1], [0, 0, 1]] + actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], + [0, 0, 1]] self.evaluate(fbeta.update_state(actuals, preds)) self.evaluate(f1.update_state(actuals, preds)) @@ -123,6 +123,11 @@ def test_keras_model(self): f1 = F1Score(5) utils.test_keras_model(f1, 5) + def test_config(self): + f1 = F1Score(3) + config = f1.get_config() + self.assertFalse("beta" in config) + if __name__ == '__main__': tf.test.main() From 89a97ebf493731026d42686a1ecd4a7c2b92db31 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Mon, 30 Sep 2019 20:04:18 +0530 Subject: [PATCH 5/8] FIX: Use sk_score for true value --- tensorflow_addons/metrics/f_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py index 8fe2a1cf76..c2c0223723 100644 --- a/tensorflow_addons/metrics/f_test.py +++ b/tensorflow_addons/metrics/f_test.py @@ -67,17 +67,19 @@ def _test_sk(self, avg, beta, act, pred, threshold): res = fbeta_score(act, pred, beta, average=avg) return res - def _test_fbeta_score(self, actuals, preds, result, threshold=None): + def _test_fbeta_score(self, actuals, preds, threshold=None): for avg in [None, 'micro', 'macro', 'weighted']: for beta_val in [0.5, 1.0, 2.0]: tf_score = self._test_tf(avg, beta_val, actuals, preds, threshold) - 
self.assertAllClose(tf_score, result, atol=1e-5) + sk_score = self._test_sk(avg, beta_val, actuals, preds, + threshold) + self.assertAllClose(tf_score, sk_score, atol=1e-5) def test_fbeta_perfect_score(self): preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] - self._test_fbeta_score(actuals, preds, 0.0, 0.66) + self._test_fbeta_score(actuals, preds, 0.66) def test_fbeta_worst_score(self): preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] From 52f7b348c22bf0d60a3b3a1c7872aa5788fd4a70 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Fri, 1 Nov 2019 17:27:02 +0530 Subject: [PATCH 6/8] Remove sklearn from f_test Resort to using hard coded test cases rather than comparing with sklearn --- tensorflow_addons/metrics/f_test.py | 75 ++++++++++++++++++----------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py index c2c0223723..a44ea874ae 100644 --- a/tensorflow_addons/metrics/f_test.py +++ b/tensorflow_addons/metrics/f_test.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +from absl.testing import parameterized + import tensorflow as tf from tensorflow_addons.metrics import FBetaScore, F1Score, utils from tensorflow_addons.utils import test_utils @@ -27,7 +29,7 @@ @test_utils.run_all_in_graph_and_eager_modes -class FBetaScoreTest(tf.test.TestCase): +class FBetaScoreTest(tf.test.TestCase, parameterized.TestCase): def test_config(self): fbeta_obj = FBetaScore( num_classes=3, beta=0.5, threshold=0.3, average=None) @@ -54,49 +56,64 @@ def _test_tf(self, avg, beta, act, pred, threshold): self.evaluate(fbeta.update_state(act, pred)) return self.evaluate(fbeta.result()) - def _test_sk(self, avg, beta, act, pred, threshold): - act = np.array(act) - pred = np.array(pred) - if threshold is None: - threshold = np.max(pred, axis=-1, keepdims=True) - pred = np.logical_and(pred >= threshold, - pred - 0 > 1e-12).astype('int') - else: - pred = (pred >= threshold).astype('int') - - res = fbeta_score(act, pred, beta, average=avg) - return res - - def _test_fbeta_score(self, actuals, preds, threshold=None): - for avg in [None, 'micro', 'macro', 'weighted']: - for beta_val in [0.5, 1.0, 2.0]: - tf_score = self._test_tf(avg, beta_val, actuals, preds, - threshold) - sk_score = self._test_sk(avg, beta_val, actuals, preds, - threshold) - self.assertAllClose(tf_score, sk_score, atol=1e-5) + def _test_fbeta_score(self, actuals, preds, avg, beta_val, result, + threshold): + tf_score = self._test_tf(avg, beta_val, actuals, preds, threshold) + self.assertAllClose(tf_score, result, atol=1e-7) def test_fbeta_perfect_score(self): preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] actuals = [[1, 1, 1], [1, 0, 0], [1, 1, 0]] - self._test_fbeta_score(actuals, preds, 0.66) + + for avg_val in ['micro', 'macro', 'weighted']: + for beta in [0.5, 1.0, 2.0]: + self._test_fbeta_score(actuals, preds, avg_val, beta, 1.0, + 0.66) def test_fbeta_worst_score(self): preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] actuals = [[0, 0, 0], [0, 1, 0], [0, 0, 1]] - self._test_fbeta_score(actuals, preds, 0.66) - def test_fbeta_random_score(self): + for avg_val in ['micro', 'macro', 'weighted']: + for beta in [0.5, 1.0, 2.0]: + self._test_fbeta_score(actuals, preds, avg_val, beta, 0.0, + 0.66) + + @parameterized.parameters([[None, 0.5, [0.71428573, 0.5, 0.833334]], + [None, 1.0, [0.8, 0.5, 0.6666667]], + [None, 2.0, [0.9090904, 0.5, 0.555556]], + ['micro', 0.5, 0.6666667], + 
['micro', 1.0, 0.6666667], + ['micro', 2.0, 0.6666667], + ['macro', 0.5, 0.6825397], + ['macro', 1.0, 0.6555555], + ['macro', 2.0, 0.6548822], + ['weighted', 0.5, 0.6825397], + ['weighted', 1.0, 0.6555555], + ['weighted', 2.0, 0.6548822]]) + def test_fbeta_random_score(self, avg_val, beta, result): preds = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]] actuals = [[0, 0, 1], [1, 1, 0], [1, 1, 1]] - self._test_fbeta_score(actuals, preds, 0.66) - - def test_fbeta_random_score_none(self): + self._test_fbeta_score(actuals, preds, avg_val, beta, result, 0.66) + + @parameterized.parameters([[None, 0.5, [0.9090904, 0.555556, 1.0]], + [None, 1.0, [0.8, 0.6666667, 1.0]], + [None, 2.0, [0.71428573, 0.833334, 1.0]], + ['micro', 0.5, 0.833334], + ['micro', 1.0, 0.833334], + ['micro', 2.0, 0.833334], + ['macro', 0.5, 0.821549], + ['macro', 1.0, 0.822222], + ['macro', 2.0, 0.849206], + ['weighted', 0.5, 0.880471], + ['weighted', 1.0, 0.844445], + ['weighted', 2.0, 0.829365]]) + def test_fbeta_random_score_none(self, avg_val, beta, result): preds = [[0.9, 0.1, 0], [0.2, 0.6, 0.2], [0, 0, 1], [0.4, 0.3, 0.3], [0, 0.9, 0.1], [0, 0, 1]] actuals = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]] - self._test_fbeta_score(actuals, preds, None) + self._test_fbeta_score(actuals, preds, avg_val, beta, result, None) def test_keras_model(self): fbeta = FBetaScore(5, 'micro', 1.0) From 85223c289cca9095b724aef13f1807c911ae6562 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Fri, 1 Nov 2019 17:49:09 +0530 Subject: [PATCH 7/8] Remove unused import --- tensorflow_addons/metrics/f_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py index a44ea874ae..112648e6f7 100644 --- a/tensorflow_addons/metrics/f_test.py +++ b/tensorflow_addons/metrics/f_test.py @@ -24,9 +24,6 @@ from tensorflow_addons.metrics import FBetaScore, F1Score, utils from tensorflow_addons.utils import test_utils -import numpy as np -from sklearn.metrics import fbeta_score - @test_utils.run_all_in_graph_and_eager_modes class FBetaScoreTest(tf.test.TestCase, parameterized.TestCase): From 0844912ab4ac69a083b8430bb5b3085ccc966916 Mon Sep 17 00:00:00 2001 From: Dheeraj Date: Fri, 1 Nov 2019 17:52:39 +0530 Subject: [PATCH 8/8] Rename test_keras_model -> _get_model --- tensorflow_addons/metrics/f_test.py | 4 ++-- tensorflow_addons/metrics/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/metrics/f_test.py b/tensorflow_addons/metrics/f_test.py index 112648e6f7..a2a51da41e 100644 --- a/tensorflow_addons/metrics/f_test.py +++ b/tensorflow_addons/metrics/f_test.py @@ -114,7 +114,7 @@ def test_fbeta_random_score_none(self, avg_val, beta, result): def test_keras_model(self): fbeta = FBetaScore(5, 'micro', 1.0) - utils.test_keras_model(fbeta, 5) + utils._get_model(fbeta, 5) @test_utils.run_all_in_graph_and_eager_modes @@ -137,7 +137,7 @@ def test_eq(self): def test_keras_model(self): f1 = F1Score(5) - utils.test_keras_model(f1, 5) + utils._get_model(f1, 5) def test_config(self): f1 = F1Score(3) diff --git a/tensorflow_addons/metrics/utils.py b/tensorflow_addons/metrics/utils.py index 65bfa9a74f..9f4be1d5fd 100644 --- a/tensorflow_addons/metrics/utils.py +++ b/tensorflow_addons/metrics/utils.py @@ -68,7 +68,7 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) -def test_keras_model(metric, num_output): +def _get_model(metric, num_output): # Test API comptibility with tf.keras Model model = 
tf.keras.Sequential() model.add(tf.keras.layers.Dense(64, activation='relu'))
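
For reference, the binarization that `update_state` applies when `threshold` is `None` (as of the last revision above, which guards all-zero rows with `abs(x) > eps`) can be sketched in plain TensorFlow as follows; `binarize` is an illustrative helper, not part of the patch:

```python
import tensorflow as tf

def binarize(y_pred, threshold=None):
    # Sketch of the prediction binarization used by FBetaScore.update_state.
    y_pred = tf.cast(y_pred, tf.float32)
    if threshold is None:
        # Per-row maxima become 1; abs(x) > eps keeps an all-zero row all zero.
        row_max = tf.reduce_max(y_pred, axis=-1, keepdims=True)
        y_pred = tf.logical_and(y_pred >= row_max, tf.abs(y_pred) > 1e-12)
    else:
        y_pred = y_pred > threshold
    return tf.cast(y_pred, tf.int32)

print(binarize(tf.constant([[0.2, 0.6, 0.2], [0.0, 0.0, 0.0]])).numpy())
# [[0 1 0]
#  [0 0 0]]
```
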