Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_addons/metrics/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ py_library(
srcs = [
"__init__.py",
"cohens_kappa.py",
"r_square.py",
],
srcs_version = "PY2AND3",
deps = [
Expand All @@ -26,3 +27,16 @@ py_test(
":metrics",
],
)

py_test(
name = "r_square_test",
size = "small",
srcs = [
"r_square_test.py",
],
main = "r_square_test.py",
srcs_version = "PY2AND3",
deps = [
":metrics",
],
)
2 changes: 2 additions & 0 deletions tensorflow_addons/metrics/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| cohens_kappa| Aakash Nain | [email protected]|
| r_square| Saishruthi Swaminathan| [email protected]|

## Contents
| Submodule | Metric | Reference |
|:----------------------- |:-------------------|:---------------|
| cohens_kappa| CohenKappa|[Cohen's Kappa](https://en.wikipedia.org/wiki/Cohen%27s_kappa)|
| r_square| RSquare|[R-Sqaure](https://en.wikipedia.org/wiki/Coefficient_of_determination)|


## Contribution Guidelines
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_addons/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
from __future__ import division
from __future__ import print_function

from tensorflow_addons.metrics.cohens_kappa import CohenKappa
from tensorflow_addons.metrics.cohens_kappa import CohenKappa
from tensorflow_addons.metrics.r_square import RSquare
73 changes: 73 additions & 0 deletions tensorflow_addons/metrics/r_square.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements R^2 scores."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.keras.metrics import Metric


class RSquare(Metric):
"""Compute R^2 score.

This is also called as coefficient of determination.
It tells how close are data to the fitted regression line.

- Highest score can be 1.0 and it indicates that the predictors
perfectly accounts for variation in the target.
- Score 0.0 indicates that the predictors do not
account for variation in the target.
- It can also be negative if the model is worse.

Usage:
```python
actuals = tf.constant([1, 4, 3], dtype=tf.float32)
preds = tf.constant([2, 4, 4], dtype=tf.float32)
result = tf.keras.metrics.RSquare()
result.update_state(actuals, preds)
print('R^2 score is: ', r1.result().numpy()) # 0.57142866
```
"""

def __init__(self, name='r_square', dtype=tf.float32):
super(RSquare, self).__init__(name=name, dtype=dtype)
self.squared_sum = self.add_weight("squared_sum", initializer="zeros")
self.sum = self.add_weight("sum", initializer="zeros")
self.res = self.add_weight("residual", initializer="zeros")
self.count = self.add_weight("count", initializer="zeros")

def update_state(self, y_true, y_pred):
y_true = tf.convert_to_tensor(y_true, tf.float32)
y_pred = tf.convert_to_tensor(y_pred, tf.float32)
self.squared_sum.assign_add(tf.reduce_sum(y_true**2))
self.sum.assign_add(tf.reduce_sum(y_true))
self.res.assign_add(
tf.reduce_sum(tf.square(tf.subtract(y_true, y_pred))))
self.count.assign_add(tf.cast(tf.shape(y_true)[0], tf.float32))

def result(self):
mean = self.sum / self.count
total = self.squared_sum - 2 * self.sum * mean + self.count * mean**2
return 1 - (self.res / total)

def reset_states(self):
# The state of the metric will be reset at the start of each epoch.
self.squared_sum.assign(0.0)
self.sum.assign(0.0)
self.res.assign(0.0)
self.count.assign(0.0)
85 changes: 85 additions & 0 deletions tensorflow_addons/metrics/r_square_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for R-Square Metric."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.metrics import RSquare


class RSquareTest(tf.test.TestCase):
def test_config(self):
r2_obj = RSquare(name='r_square')
self.assertEqual(r2_obj.name, 'r_square')
self.assertEqual(r2_obj.dtype, tf.float32)
# Check save and restore config
r2_obj2 = RSquare.from_config(r2_obj.get_config())
self.assertEqual(r2_obj2.name, 'r_square')
self.assertEqual(r2_obj2.dtype, tf.float32)

def initialize_vars(self):
r2_obj = RSquare()
self.evaluate(tf.compat.v1.variables_initializer(r2_obj.variables))
return r2_obj

def update_obj_states(self, obj, actuals, preds):
update_op = obj.update_state(actuals, preds)
self.evaluate(update_op)

def check_results(self, obj, value):
self.assertAllClose(value, self.evaluate(obj.result()), atol=1e-5)

def test_r2_perfect_score(self):
actuals = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
preds = tf.constant([100, 700, 40, 5.7], dtype=tf.float32)
actuals = tf.constant(actuals, dtype=tf.float32)
preds = tf.constant(preds, dtype=tf.float32)
# Initialize
r2_obj = self.initialize_vars()
# Update
self.update_obj_states(r2_obj, actuals, preds)
# Check results
self.check_results(r2_obj, 1.0)

def test_r2_worst_score(self):
actuals = tf.constant([10, 600, 4, 9.77], dtype=tf.float32)
preds = tf.constant([1, 70, 40, 5.7], dtype=tf.float32)
actuals = tf.constant(actuals, dtype=tf.float32)
preds = tf.constant(preds, dtype=tf.float32)
# Initialize
r2_obj = self.initialize_vars()
# Update
self.update_obj_states(r2_obj, actuals, preds)
# Check results
self.check_results(r2_obj, -0.073607)

def test_r2_random_score(self):
actuals = tf.constant([10, 600, 3, 9.77], dtype=tf.float32)
preds = tf.constant([1, 340, 40, 5.7], dtype=tf.float32)
actuals = tf.constant(actuals, dtype=tf.float32)
preds = tf.constant(preds, dtype=tf.float32)
# Initialize
r2_obj = self.initialize_vars()
# Update
self.update_obj_states(r2_obj, actuals, preds)
# Check results
self.check_results(r2_obj, 0.7376327)


if __name__ == '__main__':
tf.test.main()