37 changes: 22 additions & 15 deletions tensorflow_addons/metrics/f1_test.py
@@ -28,7 +28,7 @@
@test_utils.run_all_in_graph_and_eager_modes
class F1ScoreTest(tf.test.TestCase):
def test_config(self):
f1_obj = F1Score(num_classes=3, average=None)
f1_obj = F1Score(num_classes=3, threshold=0.75, average=None)
self.assertEqual(f1_obj.name, 'f1_score')
self.assertEqual(f1_obj.dtype, tf.float32)
self.assertEqual(f1_obj.num_classes, 3)
@@ -42,7 +42,7 @@ def test_config(self):

def initialize_vars(self, average):
# initialize variables
f1_obj = F1Score(num_classes=3, average=average)
f1_obj = F1Score(num_classes=3, threshold=0.75, average=average)

self.evaluate(tf.compat.v1.variables_initializer(f1_obj.variables))

@@ -71,22 +71,25 @@ def _test_f1_score(self, actuals, preds, res):
def test_f1_perfect_score(self):
actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
dtype=tf.int32)
preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32)
preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
dtype=tf.float32)
self._test_f1_score(actuals, preds, 1.0)

# test for worst f1 score
def test_f1_worst_score(self):
actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
dtype=tf.int32)
preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32)
preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype=tf.float32)
self._test_f1_score(actuals, preds, 0.0)

# test for random f1 score
def test_f1_random_score(self):
actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
dtype=tf.int32)
preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32)
# Use absl parameterized test here if possible
preds = tf.constant([[0.4, 0.7, 1], [1, 0.8, 0], [1, 0.9, 0.8]],
dtype=tf.float32)
# test parameters
test_params = [['micro', 0.6666667], ['macro', 0.65555555],
['weighted', 0.67777777]]

@@ -98,11 +101,11 @@ def test_f1_random_score_none(self):
actuals = tf.constant(
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype=tf.int32)
preds = tf.constant(
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
dtype=tf.int32)
preds = tf.constant([[0.8, 0.1, 0.1], [0, 1, 0], [0.11, 0.13, 0.76],
[0.8, 0.2, 0], [0.99, 0.05, 0.05], [0, 0, 1]],
dtype=tf.float32)

# Use absl parameterized test here if possible
# test parameters
test_params = [[None, [0.8, 0.6666667, 1.]]]

for avg, res in test_params:
@@ -113,17 +116,21 @@ def test_keras_model(self):
model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(1, activation='softmax'))
fb = F1Score(1, 'macro')
model.add(layers.Dense(2, activation='softmax'))

f1 = F1Score(num_classes=2, threshold=0.6, average='weighted')

model.compile(
optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['acc', f1])
# data preparation
data = np.random.random((10, 3))
labels = np.random.random((10, 1))

labels = np.random.random((10, 2))
labels = np.where(labels > 0.5, 1, 0)
model.fit(data, labels, epochs=1, batch_size=32, verbose=0)

fitted_model = model.fit(
data, labels, epochs=1, batch_size=32, verbose=0)


if __name__ == '__main__':
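Reviewer note: the expected values in test_f1_random_score can be recomputed independently of the metric class. A minimal NumPy sketch (assuming the 0.75 threshold used by initialize_vars; variable names are illustrative only):

```python
import numpy as np

# Inputs from test_f1_random_score; threshold matches initialize_vars.
actuals = np.array([[1, 1, 1], [1, 0, 0], [1, 1, 0]])
preds = np.array([[0.4, 0.7, 1], [1, 0.8, 0], [1, 0.9, 0.8]])
y_pred = (preds > 0.75).astype(int)  # same binarization the metric applies

# Per-class confusion counts (classes are columns).
tp = (y_pred * actuals).sum(axis=0).astype(float)        # [2. 1. 1.]
fp = (y_pred * (1 - actuals)).sum(axis=0).astype(float)  # [0. 1. 1.]
fn = ((1 - y_pred) * actuals).sum(axis=0).astype(float)  # [1. 1. 0.]

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)  # [0.8 0.5 0.6667]

micro_p = tp.sum() / (tp.sum() + fp.sum())
micro_r = tp.sum() / (tp.sum() + fn.sum())
support = actuals.sum(axis=0)

print(2 * micro_p * micro_r / (micro_p + micro_r))  # ~0.6666667 (micro)
print(f1.mean())                                    # ~0.6555556 (macro)
print((f1 * support).sum() / support.sum())         # ~0.6777778 (weighted)
```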
56 changes: 33 additions & 23 deletions tensorflow_addons/metrics/f_scores.py
@@ -41,13 +41,15 @@ class FBetaScore(Metric):

Args:
num_classes : Number of unique classes in the dataset.
threshold: Float representing the threshold for deciding whether
prediction values are 1 or 0.
average : Type of averaging to be performed on data.
Acceptable values are None, micro, macro and
weighted.
beta : float. Determines the weight of precision and recall
in harmonic mean. Acceptable values are either a float
greater than 0.0 or a scalar tensor of dtype tf.float32.

Returns:
F Beta Score: float
@@ -82,30 +84,30 @@ class are returned.
Usage:
```python
actuals = tf.constant([[1, 1, 0],[1, 0, 0]],
dtype=tf.int32)
preds = tf.constant([[1, 0, 0],[1, 0, 1]],
dtype=tf.int32)
dtype=tf.float32)
preds = tf.constant([[0.9, 0.2, 0.2],[0.82, 0.3, 0.85]],
dtype=tf.float32)
# F-Beta Micro
fb_score = tfa.metrics.FBetaScore(num_classes=3,
beta=0.4, average='micro')
beta=0.4, threshold=0.8, average='micro')
fb_score.update_state(actuals, preds)
print('F-Beta Score is: ',
fb_score.result().numpy()) # 0.6666666
# F-Beta Macro
fb_score = tfa.metrics.FBetaScore(num_classes=3,
beta=0.4, average='macro')
beta=0.4, threshold=0.8, average='macro')
fb_score.update_state(actuals, preds)
print('F-Beta Score is: ',
fb_score.result().numpy()) # 0.33333334
# F-Beta Weighted
fb_score = tfa.metrics.FBetaScore(num_classes=3,
beta=0.4, average='weighted')
beta=0.4, threshold=0.8, average='weighted')
fb_score.update_state(actuals, preds)
print('F-Beta Score is: ',
fb_score.result().numpy()) # 0.6666667
# F-Beta score for each class (average=None).
fb_score = tfa.metrics.FBetaScore(num_classes=3,
beta=0.4, average=None)
beta=0.4, threshold=0.8, average=None)
fb_score.update_state(actuals, preds)
print('F-Beta Score is: ',
fb_score.result().numpy()) # [1. 0. 0.]
@@ -116,10 +118,12 @@ def __init__(self,
num_classes,
average=None,
beta=1.0,
threshold=0.8,
name='fbeta_score',
dtype=tf.float32):
super(FBetaScore, self).__init__(name=name)
self.num_classes = num_classes
self.threshold = threshold
# type check
if not isinstance(beta, float) and beta.dtype != tf.float32:
raise TypeError("The value of beta should be float")
@@ -179,7 +183,7 @@ def __init__(self,
# ignored during calculations.
def update_state(self, y_true, y_pred, sample_weight=None):
y_true = tf.cast(y_true, tf.int32)
y_pred = tf.cast(y_pred, tf.int32)
y_pred = tf.cast(y_pred > self.threshold, tf.int32)

# true positive
self.true_positives.assign_add(
@@ -234,6 +238,7 @@ def get_config(self):

config = {
"num_classes": self.num_classes,
"threshold": self.threshold,
"average": self.average,
"beta": self.beta,
}
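Reviewer note: since threshold now travels through get_config, a configured metric should round-trip serialization. A small sketch, assuming the standard Keras Metric.from_config behavior (which simply calls cls(**config)):

```python
import tensorflow_addons as tfa

f1 = tfa.metrics.F1Score(num_classes=3, threshold=0.75, average='micro')
config = f1.get_config()                  # now includes 'threshold'
restored = tfa.metrics.F1Score.from_config(config)
assert restored.threshold == 0.75         # survives the round trip
```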
@@ -265,6 +270,8 @@ class F1Score(FBetaScore):

Args:
num_classes : Number of unique classes in the dataset.
threshold: Float representing the threshold for deciding whether
prediction values are 1 or 0.
average : Type of averaging to be performed on data.
Acceptable values are `None`, `micro`, `macro` and
`weighted`.
@@ -307,41 +314,44 @@ class are returned.
```python
actuals = tf.constant([[1, 1, 0],[1, 0, 0]],
dtype=tf.int32)
preds = tf.constant([[1, 0, 0],[1, 0, 1]],
dtype=tf.int32)
preds = tf.constant([[0.9, 0.4, 0.45],[1, 0.2, 0.87]],
dtype=tf.float32)
# F1 Micro
output = tfa.metrics.F1Score(num_classes=3,
output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
average='micro')
output.update_state(actuals, preds)
print('F1 Micro score is: ',
output.result().numpy()) # 0.6666667
# F1 Macro
output = tfa.metrics.F1Score(num_classes=3,
output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
average='macro')
output.update_state(actuals, preds)
print('F1 Macro score is: ',
output.result().numpy()) # 0.33333334
# F1 weighted
output = tfa.metrics.F1Score(num_classes=3,
output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
average='weighted')
output.update_state(actuals, preds)
print('F1 Weighted score is: ',
output.result().numpy()) # 0.6666667
# F1 score for each class (average=None).
output = tfa.metrics.F1Score(num_classes=3)
output = tfa.metrics.F1Score(num_classes=3, threshold=0.8,
average=None)
output.update_state(actuals, preds)
print('F1 score is: ',
output.result().numpy()) # [1. 0. 0.]
```
"""

def __init__(self, num_classes, average, name='f1_score',
def __init__(self,
num_classes,
threshold,
average,
name='f1_score',
dtype=tf.float32):
super(F1Score, self).__init__(
num_classes, average, 1.0, name=name, dtype=dtype)
num_classes, average, 1.0, threshold, name=name, dtype=dtype)

# TODO: Add sample_weight support, currently it is
# ignored during calculations.
def get_config(self):
base_config = super(F1Score, self).get_config()
del base_config["beta"]
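Reviewer note: the core behavioral change is the binarization step in update_state — predictions are compared against threshold instead of being cast directly to int. A minimal eager-mode sketch of that step on the docstring's inputs (the count_nonzero calls here are illustrative; the metric itself accumulates these counts in state variables):

```python
import tensorflow as tf

threshold = 0.8
y_true = tf.constant([[1, 1, 0], [1, 0, 0]], dtype=tf.int32)
y_pred = tf.constant([[0.9, 0.2, 0.2], [0.82, 0.3, 0.85]], dtype=tf.float32)

# The comparison the patched update_state performs before counting:
y_pred_bin = tf.cast(y_pred > threshold, tf.int32)
print(y_pred_bin.numpy())  # [[1 0 0]
                           #  [1 0 1]]

# Per-class counts analogous to the accumulated state variables:
tp = tf.math.count_nonzero(y_pred_bin * y_true, axis=0)        # [2 0 0]
fp = tf.math.count_nonzero(y_pred_bin * (1 - y_true), axis=0)  # [0 0 1]
fn = tf.math.count_nonzero((1 - y_pred_bin) * y_true, axis=0)  # [0 1 0]
```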
37 changes: 24 additions & 13 deletions tensorflow_addons/metrics/fbeta_test.py
@@ -28,7 +28,8 @@
@test_utils.run_all_in_graph_and_eager_modes
class FBetaScoreTest(tf.test.TestCase):
def test_config(self):
fbeta_obj = FBetaScore(num_classes=3, beta=0.5, average=None)
fbeta_obj = FBetaScore(
num_classes=3, beta=0.5, threshold=0.8, average=None)
self.assertEqual(fbeta_obj.beta, 0.5)
self.assertEqual(fbeta_obj.average, None)
self.assertEqual(fbeta_obj.num_classes, 3)
@@ -74,23 +75,27 @@ def _test_fbeta_score(self, actuals, preds, res):
def test_fbeta_perfect_score(self):
actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
dtype=tf.int32)
preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]], dtype=tf.int32)
preds = tf.constant([[1, 0.9, 0.98], [0.85, 0.3, 0.2], [0.85, 1, 0.6]],
dtype=tf.float32)
self._test_fbeta_score(actuals, preds, 1.0)

# test for the worst score
def test_fbeta_worst_score(self):
actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
dtype=tf.int32)
preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.int32)
preds = tf.constant(
[[0.4, 0.7, 0.65], [0.5, 0.9, 0.78], [0.65, 0.78, 1]],
dtype=tf.float32)
self._test_fbeta_score(actuals, preds, 0.0)

# test for the random score
def test_fbeta_random_score(self):
actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
dtype=tf.int32)
preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=tf.int32)
preds = tf.constant([[0.7, 0.76, 0.98], [0.9, 1, 0.6], [1, 0.9, 0.81]],
dtype=tf.float32)

# test parameters
# Use absl parameterized test here if possible
test_params = [['micro', 0.5, 0.666667], ['macro', 0.5, 0.654882],
['weighted', 0.5, 0.71380], ['micro', 2.0, 0.666667],
['macro', 2.0, 0.68253], ['weighted', 2.0, 0.66269]]
@@ -104,10 +109,11 @@ def test_fbeta_random_score_none(self):
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1]],
dtype=tf.int32)
preds = tf.constant(
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 0, 1]],
dtype=tf.int32)
[[0.9, 0.75, 0], [0.33, 1, 0.70], [0.79, 0.56, 0.98],
[1, 0.46, 0.67], [0.95, 0.44, 0.54], [0.77, 0.42, 1]],
dtype=tf.float32)

# test parameters
# Use absl parameterized test here if possible
test_params = [[0.5, [0.71428573, 0.8333334, 1.]],
[2.0, [0.90909094, 0.5555556, 1.]]]

@@ -119,17 +125,22 @@ def test_keras_model(self):
model = tf.keras.Sequential()
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(1, activation='softmax'))
fb = FBetaScore(1, 'macro')
model.add(layers.Dense(2, activation='softmax'))

fb = FBetaScore(
num_classes=2, beta=2.0, threshold=0.4, average='weighted')

model.compile(
optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['acc', fb])
# data preparation
data = np.random.random((10, 3))
labels = np.random.random((10, 1))

labels = np.random.random((10, 2))
labels = np.where(labels > 0.5, 1, 0)
model.fit(data, labels, epochs=1, batch_size=32, verbose=0)

fitted_model = model.fit(
data, labels, epochs=2, batch_size=32, verbose=0)


if __name__ == '__main__':
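Reviewer note: the per-class expectations in test_fbeta_random_score_none follow directly from the F-beta definition. A quick NumPy recomputation for beta=0.5, assuming the constructor's new default threshold of 0.8:

```python
import numpy as np

beta = 0.5
actuals = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1],
                    [1, 0, 0], [0, 1, 0], [0, 0, 1]])
preds = np.array([[0.9, 0.75, 0], [0.33, 1, 0.70], [0.79, 0.56, 0.98],
                  [1, 0.46, 0.67], [0.95, 0.44, 0.54], [0.77, 0.42, 1]])
y_pred = (preds > 0.8).astype(int)  # default threshold from the patch

tp = (y_pred * actuals).sum(axis=0).astype(float)
fp = (y_pred * (1 - actuals)).sum(axis=0).astype(float)
fn = ((1 - y_pred) * actuals).sum(axis=0).astype(float)
precision = tp / (tp + fp)  # [0.6667 1. 1.]
recall = tp / (tp + fn)     # [1. 0.5 1.]

# F-beta = (1 + beta^2) * P * R / (beta^2 * P + R)
fbeta = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
print(fbeta)  # ~[0.7142857 0.8333333 1.], matching the test expectation
```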