From d45bbc01ad547d8c57f6ac89c45b2616654278fb Mon Sep 17 00:00:00 2001
From: saishruthi
Date: Tue, 10 Sep 2019 17:28:46 -0700
Subject: [PATCH 1/2] adding threshold parameter to f1

---
 tensorflow_addons/metrics/f1_test.py    | 35 +++++++++------
 tensorflow_addons/metrics/f_scores.py   | 59 ++++++++++++++----------
 tensorflow_addons/metrics/fbeta_test.py | 37 ++++++++++------
 3 files changed, 79 insertions(+), 52 deletions(-)

diff --git a/tensorflow_addons/metrics/f1_test.py b/tensorflow_addons/metrics/f1_test.py
index e11165bb2c..92e8357d4b 100755
--- a/tensorflow_addons/metrics/f1_test.py
+++ b/tensorflow_addons/metrics/f1_test.py
@@ -28,7 +28,7 @@
 @test_utils.run_all_in_graph_and_eager_modes
 class F1ScoreTest(tf.test.TestCase):
     def test_config(self):
-        f1_obj = F1Score(num_classes=3, average=None)
+        f1_obj = F1Score(num_classes=3, threshold=0.75, average=None)
         self.assertEqual(f1_obj.name, 'f1_score')
         self.assertEqual(f1_obj.dtype, tf.float32)
         self.assertEqual(f1_obj.num_classes, 3)
@@ -42,7 +42,7 @@ def test_config(self):

     def initialize_vars(self, average):
         # initialize variables
-        f1_obj = F1Score(num_classes=3, average=average)
+        f1_obj = F1Score(num_classes=3, threshold=0.75, average=average)

         self.evaluate(tf.compat.v1.variables_initializer(f1_obj.variables))

@@ -71,21 +71,24 @@ def _test_f1_score(self, actuals, preds, res):
     def test_f1_perfect_score(self):
         actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                               dtype=tf.int32)
-        preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32)
+        preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
+                            dtype=tf.float32)
         self._test_f1_score(actuals, preds, 1.0)

     # test for worst f1 score
     def test_f1_worst_score(self):
         actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                               dtype=tf.int32)
-        preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32)
+        preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]],
+                            dtype=tf.float32)
         self._test_f1_score(actuals, preds, 0.0)

     # test for random f1 score
     def test_f1_random_score(self):
         actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                               dtype=tf.int32)
-        preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32)
+        preds = tf.constant([[0.4, 0.7, 1], [1, 0.8, 0], [1, 0.9, 0.8]],
+                            dtype=tf.float32)
         # Use absl parameterized test here if possible
         test_params = [['micro', 0.6666667], ['macro', 0.65555555],
                        ['weighted', 0.67777777]]
@@ -98,9 +101,9 @@ def test_f1_random_score_none(self):
         actuals = tf.constant(
             [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
             dtype=tf.int32)
-        preds = tf.constant(
-            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
-            dtype=tf.int32)
+        preds = tf.constant([[0.8, 0.1, 0.1], [0, 1, 0], [0.11, 0.13, 0.76],
+                             [0.8, 0.2, 0], [0.99, 0.05, 0.05], [0, 0, 1]],
+                            dtype=tf.float32)
         # Use absl parameterized test here if possible
         test_params = [[None, [0.8, 0.6666667, 1.]]]

@@ -113,17 +116,21 @@ def test_keras_model(self):
         model = tf.keras.Sequential()
         model.add(layers.Dense(64, activation='relu'))
         model.add(layers.Dense(64, activation='relu'))
-        model.add(layers.Dense(1, activation='softmax'))
-        fb = F1Score(1, 'macro')
+        model.add(layers.Dense(2, activation='softmax'))
+
+        f1 = F1Score(num_classes=2, threshold=0.6, average='weighted')
+
         model.compile(
             optimizer='rmsprop',
             loss='categorical_crossentropy',
-            metrics=['acc', fb])
-        # data preparation
+            metrics=['acc', f1])
         data = np.random.random((10, 3))
-        labels = np.random.random((10, 1))
+
+        labels = np.random.random((10, 2))
         labels = np.where(labels > 0.5, 1, 0)
-        model.fit(data, labels, epochs=1, batch_size=32, verbose=0)
+
+        fitted_model = model.fit(
+            data, labels, epochs=1, batch_size=32, verbose=0)


 if __name__ == '__main__':
diff --git a/tensorflow_addons/metrics/f_scores.py b/tensorflow_addons/metrics/f_scores.py
index 5409d79325..d487714722 100755
--- a/tensorflow_addons/metrics/f_scores.py
+++ b/tensorflow_addons/metrics/f_scores.py
@@ -41,13 +41,15 @@ class FBetaScore(Metric):

     Args:
         num_classes : Number of unique classes in the dataset.
+        threshold : Float representing the threshold for deciding whether
+            prediction values are 1 or 0.
         average : Type of averaging to be performed on data.
-            Acceptable values are None, micro, macro and
-            weighted.
+           Acceptable values are None, micro, macro and
+           weighted.
         beta : float. Determines the weight of precision and recall
-            in harmonic mean. Acceptable values are either a number
-            of float data type greater than 0.0 or a scale tensor
-            of dtype tf.float32.
+           in harmonic mean. Acceptable values are either a number
+           of float data type greater than 0.0 or a scalar tensor
+           of dtype tf.float32.

     Returns:
         F Beta Score: float
@@ -82,30 +84,30 @@ class are returned.
     Usage:
     ```python
     actuals = tf.constant([[1, 1, 0],[1, 0, 0]],
-              dtype=tf.int32)
-    preds = tf.constant([[1, 0, 0],[1, 0, 1]],
-              dtype=tf.int32)
+              dtype=tf.float32)
+    preds = tf.constant([[0.9, 0.2, 0.2],[0.82, 0.3, 0.85]],
+              dtype=tf.float32)
     # F-Beta Micro
     fb_score = tfa.metrics.FBetaScore(num_classes=3,
-               beta=0.4, average='micro')
+               beta=0.4, threshold=0.8, average='micro')
     fb_score.update_state(actuals, preds)
     print('F1-Beta Score is: ',
           fb_score.result().numpy()) # 0.6666666
     # F-Beta Macro
     fb_score = tfa.metrics.FBetaScore(num_classes=3,
-               beta=0.4, average='macro')
+               beta=0.4, threshold=0.8, average='macro')
     fb_score.update_state(actuals, preds)
     print('F1-Beta Score is: ',
          fb_score.result().numpy()) # 0.33333334
     # F-Beta Weighted
     fb_score = tfa.metrics.FBetaScore(num_classes=3,
-               beta=0.4, average='weighted')
+               beta=0.4, threshold=0.8, average='weighted')
     fb_score.update_state(actuals, preds)
     print('F1-Beta Score is: ',
           fb_score.result().numpy()) # 0.6666667
     # F-Beta score for each class (average=None).
     fb_score = tfa.metrics.FBetaScore(num_classes=3,
-               beta=0.4, average=None)
+               beta=0.4, threshold=0.8, average=None)
     fb_score.update_state(actuals, preds)
     print('F1-Beta Score is: ',
           fb_score.result().numpy()) # [1. 0. 0.]
@@ -116,10 +118,12 @@ def __init__(self,
                  num_classes,
                  average=None,
                  beta=1.0,
+                 threshold=0.8,
                  name='fbeta_score',
                  dtype=tf.float32):
         super(FBetaScore, self).__init__(name=name)
         self.num_classes = num_classes
+        self.threshold = threshold
         # type check
         if not isinstance(beta, float) and beta.dtype != tf.float32:
             raise TypeError("The value of beta should be float")
@@ -175,11 +179,10 @@ def __init__(self,
             initializer='zeros',
             dtype=self.dtype)

-    # TODO: Add sample_weight support, currently it is
-    # ignored during calculations.
+    # TO DO SSaishruthi: Add sample weight option
     def update_state(self, y_true, y_pred, sample_weight=None):
         y_true = tf.cast(y_true, tf.int32)
-        y_pred = tf.cast(y_pred, tf.int32)
+        y_pred = tf.cast(y_pred > self.threshold, tf.int32)

         # true positive
         self.true_positives.assign_add(
@@ -234,6 +237,7 @@ def get_config(self):

         config = {
             "num_classes": self.num_classes,
+            "threshold": self.threshold,
             "average": self.average,
             "beta": self.beta,
         }
@@ -265,6 +269,8 @@ class F1Score(FBetaScore):

     Args:
         num_classes : Number of unique classes in the dataset.
+        threshold : Float representing the threshold for deciding whether
+            prediction values are 1 or 0.
         average : Type of averaging to be performed on data.
             Acceptable values are `None`, `micro`, `macro` and
             `weighted`.
@@ -307,41 +313,44 @@ class are returned.
     ```python
     actuals = tf.constant([[1, 1, 0],[1, 0, 0]],
               dtype=tf.int32)
-    preds = tf.constant([[1, 0, 0],[1, 0, 1]],
-              dtype=tf.int32)
+    preds = tf.constant([[0.9, 0.4, 0.45],[1, 0.2, 0.87]],
+              dtype=tf.float32)
     # F1 Micro
-    output = tfa.metrics.F1Score(num_classes=3,
+    output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
              average='micro')
     output.update_state(actuals, preds)
     print('F1 Micro score is: ',
           output.result().numpy()) # 0.6666667
     # F1 Macro
-    output = tfa.metrics.F1Score(num_classes=3,
+    output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
              average='macro')
     output.update_state(actuals, preds)
     print('F1 Macro score is: ',
           output.result().numpy()) # 0.33333334
     # F1 weighted
-    output = tfa.metrics.F1Score(num_classes=3,
+    output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
              average='weighted')
     output.update_state(actuals, preds)
     print('F1 Weighted score is: ',
           output.result().numpy()) # 0.6666667
     # F1 score for each class (average=None).
-    output = tfa.metrics.F1Score(num_classes=3)
+    output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
+             average=None)
     output.update_state(actuals, preds)
     print('F1 score is: ',
           output.result().numpy()) # [1. 0. 0.]
     ```
     """

-    def __init__(self, num_classes, average, name='f1_score',
+    def __init__(self,
+                 num_classes,
+                 threshold,
+                 average,
+                 name='f1_score',
                  dtype=tf.float32):
         super(F1Score, self).__init__(
-            num_classes, average, 1.0, name=name, dtype=dtype)
+            num_classes, average, 1.0, threshold, name=name, dtype=dtype)

-    # TODO: Add sample_weight support, currently it is
-    # ignored during calculations.
     def get_config(self):
         base_config = super(F1Score, self).get_config()
         del base_config["beta"]
diff --git a/tensorflow_addons/metrics/fbeta_test.py b/tensorflow_addons/metrics/fbeta_test.py
index 69a5b730d1..7912e96b44 100644
--- a/tensorflow_addons/metrics/fbeta_test.py
+++ b/tensorflow_addons/metrics/fbeta_test.py
@@ -28,7 +28,8 @@
 @test_utils.run_all_in_graph_and_eager_modes
 class FBetaScoreTest(tf.test.TestCase):
     def test_config(self):
-        fbeta_obj = FBetaScore(num_classes=3, beta=0.5, average=None)
+        fbeta_obj = FBetaScore(
+            num_classes=3, beta=0.5, threshold=0.8, average=None)
         self.assertEqual(fbeta_obj.beta, 0.5)
         self.assertEqual(fbeta_obj.average, None)
         self.assertEqual(fbeta_obj.num_classes, 3)
@@ -74,23 +75,27 @@ def _test_fbeta_score(self, actuals, preds, res):
     def test_fbeta_perfect_score(self):
         actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                               dtype=tf.int32)
-        preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32)
+        preds = tf.constant([[1, 0.9, 0.98], [0.85, 0.3, 0.2], [0.85, 1, 0.6]],
+                            dtype=tf.float32)
         self._test_fbeta_score(actuals, preds, 1.0)

     # test for the worst score
     def test_fbeta_worst_score(self):
         actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                               dtype=tf.int32)
-        preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32)
+        preds = tf.constant(
+            [[0.4, 0.7, 0.65], [0.5, 0.9, 0.78], [0.65, 0.78, 1]],
+            dtype=tf.float32)
         self._test_fbeta_score(actuals, preds, 0.0)

     # test for the random score
     def test_fbeta_random_score(self):
         actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                               dtype=tf.int32)
-        preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32)
+        preds = tf.constant([[0.7, 0.76, 0.98], [0.9, 1, 0.6], [1, 0.9, 0.81]],
+                            dtype=tf.float32)

-        # test parameters
+        # Use absl parameterized test here if possible
         test_params = [['micro', 0.5, 0.666667], ['macro', 0.5, 0.654882],
                        ['weighted', 0.5, 0.71380], ['micro', 2.0, 0.666667],
                        ['macro', 2.0, 0.68253], ['weighted', 2.0, 0.66269]]
@@ -104,10 +109,11 @@ def test_fbeta_random_score_none(self):
             [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
             dtype=tf.int32)
         preds = tf.constant(
-            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
-            dtype=tf.int32)
+            [[0.9, 0.75, 0], [0.33, 1, 0.70], [0.79, 0.56, 0.98],
+             [1, 0.46, 0.67], [0.95, 0.44, 0.54], [0.77, 0.42, 1]],
+            dtype=tf.float32)

-        # test parameters
+        # Use absl parameterized test here if possible
         test_params = [[0.5, [0.71428573, 0.8333334, 1.]],
                        [2.0, [0.90909094, 0.5555556, 1.]]]

@@ -119,17 +125,22 @@ def test_keras_model(self):
         model = tf.keras.Sequential()
         model.add(layers.Dense(64, activation='relu'))
         model.add(layers.Dense(64, activation='relu'))
-        model.add(layers.Dense(1, activation='softmax'))
-        fb = FBetaScore(1, 'macro')
+        model.add(layers.Dense(2, activation='softmax'))
+
+        fb = FBetaScore(
+            num_classes=2, beta=2.0, threshold=0.4, average='weighted')
+
         model.compile(
             optimizer='rmsprop',
             loss='categorical_crossentropy',
             metrics=['acc', fb])
-        # data preparation
         data = np.random.random((10, 3))
-        labels = np.random.random((10, 1))
+
+        labels = np.random.random((10, 2))
         labels = np.where(labels > 0.5, 1, 0)
-        model.fit(data, labels, epochs=1, batch_size=32, verbose=0)
+
+        fitted_model = model.fit(
+            data, labels, epochs=2, batch_size=32, verbose=0)


 if __name__ == '__main__':

From 01318dec33faad35508297cf41f58b60c5abf9dc Mon Sep 17 00:00:00 2001
From: saishruthi
Date: Tue, 10 Sep 2019 17:35:32 -0700
Subject: [PATCH 2/2] adding threshold parameter to f1

---
 tensorflow_addons/metrics/f1_test.py  | 4 ++--
 tensorflow_addons/metrics/f_scores.py | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/tensorflow_addons/metrics/f1_test.py b/tensorflow_addons/metrics/f1_test.py
index 92e8357d4b..48ef9e2392 100755
--- a/tensorflow_addons/metrics/f1_test.py
+++ b/tensorflow_addons/metrics/f1_test.py
@@ -89,7 +89,7 @@ def test_f1_random_score(self):
             dtype=tf.int32)
         preds = tf.constant([[0.4, 0.7, 1], [1, 0.8, 0], [1, 0.9, 0.8]],
                             dtype=tf.float32)
-        # Use absl parameterized test here if possible
+        # test parameters
         test_params = [['micro', 0.6666667], ['macro', 0.65555555],
                        ['weighted', 0.67777777]]

@@ -105,7 +105,7 @@ def test_f1_random_score_none(self):
             [0.8, 0.2, 0], [0.99, 0.05, 0.05], [0, 0, 1]],
             dtype=tf.float32)

-        # Use absl parameterized test here if possible
+        # test parameters
         test_params = [[None, [0.8, 0.6666667, 1.]]]

         for avg, res in test_params:
diff --git a/tensorflow_addons/metrics/f_scores.py b/tensorflow_addons/metrics/f_scores.py
index d487714722..510dc5befa 100755
--- a/tensorflow_addons/metrics/f_scores.py
+++ b/tensorflow_addons/metrics/f_scores.py
@@ -179,7 +179,8 @@ def __init__(self,
             initializer='zeros',
             dtype=self.dtype)

-    # TO DO SSaishruthi: Add sample weight option
+    # TODO: Add sample_weight support, currently it is
+    # ignored during calculations.
     def update_state(self, y_true, y_pred, sample_weight=None):
         y_true = tf.cast(y_true, tf.int32)
         y_pred = tf.cast(y_pred > self.threshold, tf.int32)
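
To sanity-check the thresholding semantics this series introduces, here is a minimal standalone sketch. It is not part of the patches: it assumes TensorFlow 2.x and hand-rolls the per-class counts instead of calling the library, reproducing the `average=None` result from the FBetaScore docstring example with `threshold=0.8`:

```python
import tensorflow as tf

# Inputs from the FBetaScore docstring example above.
actuals = tf.constant([[1, 1, 0], [1, 0, 0]], dtype=tf.int32)
preds = tf.constant([[0.9, 0.2, 0.2], [0.82, 0.3, 0.85]], dtype=tf.float32)

# Binarize exactly as the patched update_state() does:
#     y_pred = tf.cast(y_pred > self.threshold, tf.int32)
threshold = 0.8
y_pred = tf.cast(preds > threshold, tf.int32)  # [[1, 0, 0], [1, 0, 1]]

# Per-class confusion counts (columns are classes). The library keeps
# these as running variables; here they are computed in one shot.
tp = tf.cast(tf.reduce_sum(y_pred * actuals, axis=0), tf.float32)
fp = tf.cast(tf.reduce_sum(y_pred * (1 - actuals), axis=0), tf.float32)
fn = tf.cast(tf.reduce_sum((1 - y_pred) * actuals, axis=0), tf.float32)

precision = tf.math.divide_no_nan(tp, tp + fp)
recall = tf.math.divide_no_nan(tp, tp + fn)
f1 = tf.math.divide_no_nan(2 * precision * recall, precision + recall)
print(f1.numpy())  # [1. 0. 0.], matching the average=None comment above
```

The only behavioral change the series makes is that single binarization step, `y_pred > threshold`; the running true/false positive counts and the averaging modes downstream of it are untouched.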