14 changes: 14 additions & 0 deletions tensorflow_addons/metrics/BUILD
@@ -7,6 +7,7 @@ py_library(
    srcs = [
        "__init__.py",
        "cohens_kappa.py",
        "f1_scores.py",
        "r_square.py",
    ],
    srcs_version = "PY2AND3",
@@ -40,3 +41,16 @@ py_test(
        ":metrics",
    ],
)

py_test(
    name = "f1_test",
    size = "small",
    srcs = [
        "f1_test.py",
    ],
    main = "f1_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":metrics",
    ],
)
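With this target in place, the new test should be runnable from the repository root with the standard Bazel invocation, e.g. `bazel test //tensorflow_addons/metrics:f1_test` (exact flags depend on the project's CI configuration).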
2 changes: 2 additions & 0 deletions tensorflow_addons/metrics/README.md
@@ -4,12 +4,14 @@
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| cohens_kappa| Aakash Nain | [email protected]|
| f1_scores| Saishruthi Swaminathan | [email protected]|
| r_square| Saishruthi Swaminathan| [email protected]|

## Contents
| Submodule | Metric | Reference |
|:----------------------- |:-------------------|:---------------|
| cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)|
| f1_scores| F1 micro, macro and weighted| [F1 Score](https://en.wikipedia.org/wiki/F1_score)|
| r_square| RSquare|[R-Square](https://en.wikipedia.org/wiki/Coefficient_of_determination)|


1 change: 1 addition & 0 deletions tensorflow_addons/metrics/__init__.py
100644 → 100755
@@ -19,4 +19,5 @@
from __future__ import print_function

from tensorflow_addons.metrics.cohens_kappa import CohenKappa
from tensorflow_addons.metrics.f1_scores import F1Score
from tensorflow_addons.metrics.r_square import RSquare
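With the export above in place, here is a minimal usage sketch of the new metric (illustrative only, not part of the diff; it assumes TF 2.x eager execution and that `tensorflow_addons` is importable):

```python
import tensorflow as tf
from tensorflow_addons.metrics import F1Score

# Two samples, three classes, in multi-label indicator format.
y_true = tf.constant([[1, 0, 0], [0, 1, 0]], dtype=tf.int32)
y_pred = tf.constant([[1, 0, 0], [0, 0, 1]], dtype=tf.int32)

metric = F1Score(num_classes=3, average='macro')
metric.update_state(y_true, y_pred)
# Per-class F1 is [1.0, 0.0, 0.0]; the macro average is ~0.3333.
print(metric.result().numpy())
```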
208 changes: 208 additions & 0 deletions tensorflow_addons/metrics/f1_scores.py
@@ -0,0 +1,208 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements F1 scores."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.keras.metrics import Metric
import numpy as np


class F1Score(Metric):
    """Computes F1 micro, macro or weighted, based on the user's choice.

    F1 score is the harmonic mean of precision and
    recall. Output range is [0, 1]. This works for both
    multi-class and multi-label classification.

    Args:
        num_classes: Number of unique classes in the dataset.
        average: Type of averaging to be performed on data.
            Acceptable values are `None`, `micro`, `macro`
            and `weighted`. Default value is None.

    Returns:
        F1 score: float.

    Raises:
        ValueError: If `average` has a value other than
            [None, 'micro', 'macro', 'weighted'].

    `average` parameter behavior:

    1. If `None` is specified, scores for each
       class are returned.

    2. If `micro` is specified, true positives,
       false positives and false negatives are computed
       globally.

    3. If `macro` is specified, true positives,
       false positives and false negatives are computed for
       each class and their unweighted mean is returned.
       Imbalance in the dataset is not taken into account
       when calculating the score.

    4. If `weighted` is specified, metrics are computed for
       each class and their mean, weighted by the number of
       true instances in each class, is returned. This takes
       data imbalance into account.

    Usage:
    ```python
    actuals = tf.constant([[1, 1, 0], [1, 0, 0]],
                          dtype=tf.int32)
    preds = tf.constant([[1, 0, 0], [1, 0, 1]],
                        dtype=tf.int32)
    output = F1Score(num_classes=3, average='micro')
    output.update_state(actuals, preds)
    print('F1 Micro score is: ',
          output.result().numpy())  # 0.6666667
    ```
    """

    def __init__(self,
                 num_classes,
                 average=None,
                 name='f1_score',
                 dtype=tf.float32):
        super(F1Score, self).__init__(name=name, dtype=dtype)
        self.num_classes = num_classes
        if average not in (None, 'micro', 'macro', 'weighted'):
            raise ValueError("Unknown average type. Acceptable values "
                             "are: [None, micro, macro, weighted]")
        self.average = average
        if self.average == 'micro':
            # Counts are accumulated globally: scalar state variables.
            self.axis = None
            init_shape = []
        else:
            # Counts are accumulated per class.
            self.axis = 0
            init_shape = [self.num_classes]
        self.true_positives = self.add_weight(
            'true_positives',
            shape=init_shape,
            initializer='zeros',
            dtype=tf.float32)
        self.false_positives = self.add_weight(
            'false_positives',
            shape=init_shape,
            initializer='zeros',
            dtype=tf.float32)
        self.false_negatives = self.add_weight(
            'false_negatives',
            shape=init_shape,
            initializer='zeros',
            dtype=tf.float32)
        if self.average != 'micro':
            # Per-class support, used for the weighted average.
            self.weights_intermediate = self.add_weight(
                'weights',
                shape=[self.num_classes],
                initializer='zeros',
                dtype=tf.float32)

    def update_state(self, y_true, y_pred):
        y_true = tf.cast(y_true, tf.int32)
        y_pred = tf.cast(y_pred, tf.int32)

        # true positives: predicted 1 where the label is 1
        self.true_positives.assign_add(
            tf.cast(
                tf.math.count_nonzero(y_pred * y_true, axis=self.axis),
                tf.float32))
        # false positives: predicted 1 where the label is 0
        self.false_positives.assign_add(
            tf.cast(
                tf.math.count_nonzero(y_pred * (y_true - 1), axis=self.axis),
                tf.float32))
        # false negatives: predicted 0 where the label is 1
        self.false_negatives.assign_add(
            tf.cast(
                tf.math.count_nonzero((y_pred - 1) * y_true, axis=self.axis),
                tf.float32))
        if self.average == 'weighted':
            # accumulate per-class support (number of true instances)
            self.weights_intermediate.assign_add(
                tf.cast(tf.reduce_sum(y_true, axis=self.axis), tf.float32))

    def result(self):
        p_sum = tf.cast(self.true_positives + self.false_positives,
                        tf.float32)
        # precision = TP / (TP + FP), with 0/0 treated as 0
        precision = tf.math.divide_no_nan(self.true_positives, p_sum)

        r_sum = tf.cast(self.true_positives + self.false_negatives,
                        tf.float32)
        # recall = TP / (TP + FN), with 0/0 treated as 0
        recall = tf.math.divide_no_nan(self.true_positives, r_sum)

        mul_value = 2 * precision * recall
        add_value = precision + recall
        f1_int = tf.math.divide_no_nan(mul_value, add_value)

        if self.average == 'weighted':
            # weight per-class scores by per-class support
            f1_int_weights = tf.math.divide_no_nan(
                self.weights_intermediate,
                tf.reduce_sum(self.weights_intermediate))
            f1_score = tf.reduce_sum(f1_int * f1_int_weights)
        elif self.average is not None:  # 'micro' or 'macro'
            f1_score = tf.reduce_mean(f1_int)
        else:  # average=None: return per-class scores
            f1_score = f1_int

        return f1_score

    def get_config(self):
        """Returns the serializable config of the metric."""

        config = {
            "num_classes": self.num_classes,
            "average": self.average,
        }
        base_config = super(F1Score, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def reset_states(self):
        # reset the state variables to zero
        if self.average == 'micro':
            self.true_positives.assign(0)
            self.false_positives.assign(0)
            self.false_negatives.assign(0)
        else:
            zeros = np.zeros(self.num_classes, dtype=np.float32)
            self.true_positives.assign(zeros)
            self.false_positives.assign(zeros)
            self.false_negatives.assign(zeros)
            self.weights_intermediate.assign(zeros)
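To make the three averaging modes concrete, here is an illustrative NumPy recomputation of the docstring example above (a sketch mirroring the `count_nonzero` arithmetic in `update_state`, not part of the diff):

```python
import numpy as np

y_true = np.array([[1, 1, 0], [1, 0, 0]])
y_pred = np.array([[1, 0, 0], [1, 0, 1]])

# Per-class counts (axis=0), as accumulated when average != 'micro'.
tp = np.count_nonzero(y_pred * y_true, axis=0).astype(float)        # [2, 0, 0]
fp = np.count_nonzero(y_pred * (y_true - 1), axis=0).astype(float)  # [0, 0, 1]
fn = np.count_nonzero((y_pred - 1) * y_true, axis=0).astype(float)  # [0, 1, 0]

def safe_div(a, b):
    # 0/0 -> 0, matching tf.math.divide_no_nan.
    return np.divide(a, b, out=np.zeros_like(a), where=b != 0)

precision = safe_div(tp, tp + fp)
recall = safe_div(tp, tp + fn)
f1 = safe_div(2 * precision * recall, precision + recall)  # [1., 0., 0.]

micro = 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum())  # 0.6667
macro = f1.mean()                                            # 0.3333
support = y_true.sum(axis=0)                                 # [2, 1, 0]
weighted = (f1 * support / support.sum()).sum()              # 0.6667
print(micro, macro, weighted)
```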
127 changes: 127 additions & 0 deletions tensorflow_addons/metrics/f1_test.py
@@ -0,0 +1,127 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests F1 metrics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.metrics import F1Score
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class F1ScoreTest(tf.test.TestCase):
    def test_config(self):
        f1_obj = F1Score(name='f1_score', num_classes=3)
        self.assertEqual(f1_obj.name, 'f1_score')
        self.assertEqual(f1_obj.dtype, tf.float32)
        self.assertEqual(f1_obj.num_classes, 3)
        # Check save and restore config
        f1_obj2 = F1Score.from_config(f1_obj.get_config())
        self.assertEqual(f1_obj2.name, 'f1_score')
        self.assertEqual(f1_obj2.dtype, tf.float32)
        self.assertEqual(f1_obj2.num_classes, 3)

    def initialize_vars(self):
        f1_micro = F1Score(num_classes=3, average='micro')
        f1_macro = F1Score(num_classes=3, average='macro')
        f1_weighted = F1Score(num_classes=3, average='weighted')

        self.evaluate(tf.compat.v1.variables_initializer(f1_micro.variables))
        self.evaluate(tf.compat.v1.variables_initializer(f1_macro.variables))
        self.evaluate(
            tf.compat.v1.variables_initializer(f1_weighted.variables))
        return f1_micro, f1_macro, f1_weighted

    def initialize_vars_none(self):
        f1_none = F1Score(num_classes=3, average=None)

        self.evaluate(tf.compat.v1.variables_initializer(f1_none.variables))
        return f1_none

    def update_obj_states(self, f1_micro, f1_macro, f1_weighted, actuals,
                          preds):
        update_micro = f1_micro.update_state(actuals, preds)
        update_macro = f1_macro.update_state(actuals, preds)
        update_weighted = f1_weighted.update_state(actuals, preds)
        self.evaluate(update_micro)
        self.evaluate(update_macro)
        self.evaluate(update_weighted)

    def update_obj_states_none(self, f1_none, actuals, preds):
        update_none = f1_none.update_state(actuals, preds)
        self.evaluate(update_none)

    def check_results(self, obj, value):
        self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5)

    def test_f1_perfect_score(self):
        actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                              dtype=tf.int32)
        preds = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                            dtype=tf.int32)
        # Initialize
        f1_micro, f1_macro, f1_weighted = self.initialize_vars()
        # Update
        self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals,
                               preds)
        # Check results
        self.check_results(f1_micro, 1.0)
        self.check_results(f1_macro, 1.0)
        self.check_results(f1_weighted, 1.0)

    def test_f1_worst_score(self):
        actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                              dtype=tf.int32)
        preds = tf.constant([[0, 0, 0], [0, 1, 0], [0, 0, 1]],
                            dtype=tf.int32)
        # Initialize
        f1_micro, f1_macro, f1_weighted = self.initialize_vars()
        # Update
        self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals,
                               preds)
        # Check results
        self.check_results(f1_micro, 0.0)
        self.check_results(f1_macro, 0.0)
        self.check_results(f1_weighted, 0.0)

    def test_f1_random_score(self):
        actuals = tf.constant([[1, 1, 1], [1, 0, 0], [1, 1, 0]],
                              dtype=tf.int32)
        preds = tf.constant([[0, 0, 1], [1, 1, 0], [1, 1, 1]],
                            dtype=tf.int32)
        # Initialize
        f1_micro, f1_macro, f1_weighted = self.initialize_vars()
        # Update
        self.update_obj_states(f1_micro, f1_macro, f1_weighted, actuals,
                               preds)
        # Check results
        self.check_results(f1_micro, 0.6666666)
        self.check_results(f1_macro, 0.6555555)
        self.check_results(f1_weighted, 0.6777777)

    def test_f1_none_score(self):
        actuals = tf.constant(
            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0],
             [0, 0, 1]],
            dtype=tf.int32)
        preds = tf.constant(
            [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0],
             [0, 0, 1]],
            dtype=tf.int32)
        # Initialize
        f1_none = self.initialize_vars_none()
        # Update
        self.update_obj_states_none(f1_none, actuals, preds)
        # Check results
        self.check_results(f1_none, [0.8, 0.6666667, 1.])


if __name__ == '__main__':
    tf.test.main()
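As a sanity check on the per-class expectations in `test_f1_none_score`, the same values can be reproduced with scikit-learn (illustrative only; sklearn is not a dependency of this PR):

```python
import numpy as np
from sklearn.metrics import f1_score

# Same one-hot data as test_f1_none_score, converted to class labels.
y_true = np.argmax([[1, 0, 0], [0, 1, 0], [0, 0, 1],
                    [1, 0, 0], [0, 1, 0], [0, 0, 1]], axis=1)
y_pred = np.argmax([[1, 0, 0], [0, 1, 0], [0, 0, 1],
                    [1, 0, 0], [1, 0, 0], [0, 0, 1]], axis=1)

# Per-class F1; expected output: [0.8, 0.66666667, 1.]
print(f1_score(y_true, y_pred, average=None))
```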