diff --git a/tensorflow_addons/metrics/__init__.py b/tensorflow_addons/metrics/__init__.py index 8a2d6b7464..61a3b3f3bc 100755 --- a/tensorflow_addons/metrics/__init__.py +++ b/tensorflow_addons/metrics/__init__.py @@ -31,4 +31,9 @@ from tensorflow_addons.metrics.r_square import RSquare from tensorflow_addons.metrics.geometric_mean import GeometricMean from tensorflow_addons.metrics.harmonic_mean import HarmonicMean -from tensorflow_addons.metrics.kendalls_tau import KendallsTau +from tensorflow_addons.metrics.streaming_correlations import ( + KendallsTauB, + KendallsTauC, + PearsonsCorrelation, + SpearmansRank, +) diff --git a/tensorflow_addons/metrics/kendalls_tau.py b/tensorflow_addons/metrics/kendalls_tau.py deleted file mode 100644 index 3f1a69f392..0000000000 --- a/tensorflow_addons/metrics/kendalls_tau.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Approximate Kendall's Tau-b Metric.""" - -import tensorflow as tf -from tensorflow.keras.metrics import Metric -from tensorflow_addons.utils.types import AcceptableDTypes - -from typeguard import typechecked - - -@tf.keras.utils.register_keras_serializable(package="Addons") -class KendallsTau(Metric): - """Computes Kendall's Tau-b Rank Correlation Coefficient. 
- - A measure of ordinal similarity between equal length sequences - of values, with allowances for ties. - - Based on the algorithm of Wei Xiao https://arxiv.org/abs/1712.01521. - - Usage: - - ```python - actuals = tf.constant([12, 2, 1, 12, 2], dtype=np.int32) - preds = tf.constant([1, 4, 7, 1, 0], dtype=np.int32) - - m = tfa.metrics.KendallsTau(0, 13) - m.update_state(actuals, preds) - print('Final result: ', m.result().numpy()) # Result: -0.4714045 - ``` - - """ - - @typechecked - def __init__( - self, - actual_min: float = 0.0, - actual_max: float = 1.0, - preds_min: float = 0.0, - preds_max: float = 1.0, - actual_cutpoints: int = 100, - preds_cutpoints: int = 100, - name: str = "kendalls_tau", - dtype: AcceptableDTypes = None, - ): - """Creates a `KendallsTau` instance. - - Args: - actual_min: the inclusive lower bound on values from actual. - actual_max: the exclusive upper bound on values from actual. - preds_min: the inclusive lower bound on values from preds. - preds_max: the exclusive upper bound on values from preds. - actual_cutpoints: the number of divisions to create in actual range, - defaults to 100. - preds_cutpoints: the number of divisions to create in preds range, - defaults to 100. - name: (optional) String name of the metric instance - dtype: (optional) Data type of the metric result. 
Defaults to `None` - """ - super().__init__(name=name, dtype=dtype) - self.actual_min = actual_min - self.actual_max = actual_max - self.preds_min = preds_min - self.preds_max = preds_max - self.actual_cutpoints = actual_cutpoints - self.preds_cutpoints = preds_cutpoints - self.actual_cuts = tf.linspace( - tf.cast(self.actual_min, tf.float32), - tf.cast(self.actual_max, tf.float32), - self.actual_cutpoints - 1, - ) - self.preds_cuts = tf.linspace( - tf.cast(self.preds_min, tf.float32), - tf.cast(self.preds_max, tf.float32), - self.preds_cutpoints - 1, - ) - self.m = self.add_weight( - "m", (self.actual_cutpoints, self.preds_cutpoints), dtype=tf.int64 - ) - self.nrow = self.add_weight("nrow", (self.actual_cutpoints), dtype=tf.int64) - self.ncol = self.add_weight("ncol", (self.preds_cutpoints), dtype=tf.int64) - self.n = self.add_weight("n", (), dtype=tf.int64) - - def update_state(self, y_true, y_pred, sample_weight=None): - """Accumulates ranks. - - Args: - y_true: actual rank values. - y_pred: predicted rank values. - sample_weight (optional): Ignored. - - Returns: - Update op. 
- """ - i = tf.searchsorted( - self.actual_cuts, - tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype), - ) - j = tf.searchsorted( - self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype) - ) - - m = tf.sparse.from_dense(self.m) - nrow = tf.sparse.from_dense(self.nrow) - ncol = tf.sparse.from_dense(self.ncol) - - k = 0 - while k < tf.shape(i)[0]: - m = tf.sparse.add( - m, - tf.SparseTensor( - [[i[k], j[k]]], - tf.cast([1], dtype=m.dtype), - self.m.shape, - ), - ) - nrow = tf.sparse.add( - nrow, - tf.SparseTensor( - [[i[k]]], - tf.cast([1], dtype=nrow.dtype), - self.nrow.shape, - ), - ) - ncol = tf.sparse.add( - ncol, - tf.SparseTensor( - [[j[k]]], - tf.cast([1], dtype=ncol.dtype), - self.ncol.shape, - ), - ) - k += 1 - - self.n.assign_add(tf.cast(k, tf.int64)) - self.m.assign(tf.sparse.to_dense(m)) - self.nrow.assign(tf.sparse.to_dense(nrow)) - self.ncol.assign(tf.sparse.to_dense(ncol)) - - def result(self): - m = tf.cast(self.m, tf.float32) - n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1) - # Number of concordant pairs. - p = tf.math.reduce_sum(tf.multiply(n_cap[:-1, :-1], m[1:, 1:])) - sum_m_squard = tf.math.reduce_sum(tf.math.square(m)) - # Ties in x. - t = ( - tf.cast(tf.math.reduce_sum(tf.math.square(self.nrow)), tf.float32) - - sum_m_squard - ) / 2.0 - # Ties in y. - u = ( - tf.cast(tf.math.reduce_sum(tf.math.square(self.ncol)), tf.float32) - - sum_m_squard - ) / 2.0 - # Ties in both. - b = tf.math.reduce_sum(tf.multiply(m, (m - 1.0))) / 2.0 - # Number of discordant pairs. 
- n = tf.cast(self.n, tf.float32) - q = (n - 1.0) * n / 2.0 - p - t - u - b - return (p - q) / tf.math.sqrt((p + q + t) * (p + q + u)) - - def get_config(self): - """Returns the serializable config of the metric.""" - - config = { - "actual_min": self.actual_min, - "actual_max": self.actual_max, - "preds_min": self.preds_min, - "preds_max": self.preds_max, - "actual_cutpoints": self.actual_cutpoints, - "preds_cutpoints": self.preds_cutpoints, - } - base_config = super().get_config() - return {**base_config, **config} - - def reset_state(self): - """Resets all of the metric state variables.""" - - self.m.assign(tf.zeros((self.actual_cutpoints, self.preds_cutpoints), tf.int64)) - self.nrow.assign(tf.zeros((self.actual_cutpoints), tf.int64)) - self.ncol.assign(tf.zeros((self.preds_cutpoints), tf.int64)) - self.n.assign(0) - - def reset_states(self): - # Backwards compatibility alias of `reset_state`. New classes should - # only implement `reset_state`. - # Required in Tensorflow < 2.5.0 - return self.reset_state() diff --git a/tensorflow_addons/metrics/streaming_correlations.py b/tensorflow_addons/metrics/streaming_correlations.py new file mode 100644 index 0000000000..330205f189 --- /dev/null +++ b/tensorflow_addons/metrics/streaming_correlations.py @@ -0,0 +1,336 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Approximate Pearson's, Spearman's, Kendall's Tau-b/c correlations based
on the algorithm of Wei Xiao https://arxiv.org/abs/1712.01521."""
from abc import abstractmethod

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend
from tensorflow.keras.metrics import Metric
from tensorflow_addons.utils.types import AcceptableDTypes
from typeguard import typechecked


class CorrelationBase(Metric):
    """Base class for streaming correlation metrics.

    Based on https://arxiv.org/abs/1712.01521.

    It stores and updates the joint and marginal histograms of
    (`y_true`, `y_pred`):

    * `m`: joint histogram, shape
      `(actual_cutpoints - 1, preds_cutpoints - 1)`.
    * `nrow`: marginal histogram of `y_true`, shape `(actual_cutpoints - 1,)`.
    * `ncol`: marginal histogram of `y_pred`, shape `(preds_cutpoints - 1,)`.
    * `n`: total number of samples seen.

    The concrete classes estimate the different correlation metrics
    based on those histograms.
    """

    @typechecked
    def __init__(
        self,
        actual_min: float = 0.0,
        actual_max: float = 1.0,
        preds_min: float = 0.0,
        preds_max: float = 1.0,
        actual_cutpoints: int = 100,
        preds_cutpoints: int = 100,
        name: str = None,
        dtype: AcceptableDTypes = None,
    ):
        """Creates a `CorrelationBase` instance.

        Args:
          actual_min: the lower bound on values from actual. Smaller values
            are clipped to it.
          actual_max: the upper bound on values from actual. Larger values
            are clipped to it and counted in the last bin.
          preds_min: the lower bound on values from preds. Smaller values
            are clipped to it.
          preds_max: the upper bound on values from preds. Larger values
            are clipped to it and counted in the last bin.
          actual_cutpoints: the number of cut points placed on the actual
            range (yielding `actual_cutpoints - 1` bins), defaults to 100.
          preds_cutpoints: the number of cut points placed on the preds
            range (yielding `preds_cutpoints - 1` bins), defaults to 100.
          name: (optional) String name of the metric instance
          dtype: (optional) Data type of the metric result.
            Defaults to `None`
        """
        super().__init__(name=name, dtype=dtype)
        self.actual_min = actual_min
        self.actual_max = actual_max
        self.preds_min = preds_min
        self.preds_max = preds_max
        self.actual_cutpoints = actual_cutpoints
        self.preds_cutpoints = preds_cutpoints
        self.actual_cuts = self._make_cuts(actual_min, actual_max, actual_cutpoints)
        self.preds_cuts = self._make_cuts(preds_min, preds_max, preds_cutpoints)

        n_rows = self.actual_cutpoints - 1
        n_cols = self.preds_cutpoints - 1
        self.m = self.add_weight("m", (n_rows, n_cols), dtype=tf.int64)
        self.nrow = self.add_weight("nrow", (n_rows,), dtype=tf.int64)
        self.ncol = self.add_weight("ncol", (n_cols,), dtype=tf.int64)
        self.n = self.add_weight("n", (), dtype=tf.int64)

    @staticmethod
    def _make_cuts(lower, upper, cutpoints):
        """Returns `cutpoints` evenly spaced float32 bin edges on [lower, upper]."""
        cuts = np.linspace(
            tf.cast(lower, tf.float32),
            tf.cast(upper, tf.float32),
            cutpoints,
        )
        # Nudge the last edge up so that `upper` itself lands in the last bin.
        # The nudge can vanish once rounded to float32 (large `upper`), which
        # is why `update_state` additionally clamps the bin indices.
        cuts[-1] += backend.epsilon()
        return tf.convert_to_tensor(cuts, tf.float32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Updates `m`, `nrow`, `ncol` respectively the joint and
        marginal histograms of (`y_true`, `y_pred`)

        Args:
          y_true: actual values; flattened before binning.
          y_pred: predicted values; flattened before binning.
          sample_weight (optional): Ignored.
        """
        # Cast first so that integer inputs can be clipped against float bounds.
        actual = tf.reshape(tf.cast(y_true, self.actual_cuts.dtype), [-1])
        preds = tf.reshape(tf.cast(y_pred, self.preds_cuts.dtype), [-1])
        actual = tf.clip_by_value(actual, self.actual_min, self.actual_max)
        preds = tf.clip_by_value(preds, self.preds_min, self.preds_max)

        # Bin index = (number of edges <= value) - 1.
        i = tf.searchsorted(self.actual_cuts, actual, side="right") - 1
        j = tf.searchsorted(self.preds_cuts, preds, side="right") - 1
        # A value equal to the upper bound would otherwise index one past the
        # last bin whenever the epsilon nudge is lost in float32.
        i = tf.clip_by_value(i, 0, self.actual_cutpoints - 2)
        j = tf.clip_by_value(j, 0, self.preds_cutpoints - 2)

        # One vectorised scatter-add per histogram; duplicate indices are
        # accumulated, so every sample adds exactly one count.
        ones = tf.ones_like(i, dtype=self.m.dtype)
        self.m.assign(
            tf.tensor_scatter_nd_add(self.m, tf.stack([i, j], axis=1), ones)
        )
        self.nrow.assign(
            tf.tensor_scatter_nd_add(self.nrow, tf.expand_dims(i, axis=1), ones)
        )
        self.ncol.assign(
            tf.tensor_scatter_nd_add(self.ncol, tf.expand_dims(j, axis=1), ones)
        )
        self.n.assign_add(tf.size(i, out_type=self.n.dtype))

    @abstractmethod
    def result(self):
        """Computes the correlation estimate from the accumulated histograms."""
        pass

    def get_config(self):
        """Returns the serializable config of the metric."""

        config = {
            "actual_min": self.actual_min,
            "actual_max": self.actual_max,
            "preds_min": self.preds_min,
            "preds_max": self.preds_max,
            "actual_cutpoints": self.actual_cutpoints,
            "preds_cutpoints": self.preds_cutpoints,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    def reset_state(self):
        """Resets all of the metric state variables."""

        self.m.assign(
            tf.zeros((self.actual_cutpoints - 1, self.preds_cutpoints - 1), tf.int64)
        )
        self.nrow.assign(tf.zeros((self.actual_cutpoints - 1,), tf.int64))
        self.ncol.assign(tf.zeros((self.preds_cutpoints - 1,), tf.int64))
        self.n.assign(0)

    def reset_states(self):
        # Backwards compatibility alias of `reset_state`. New classes should
        # only implement `reset_state`.
        # Required in Tensorflow < 2.5.0
        return self.reset_state()


class KendallsTauBase(CorrelationBase):
    """Base class for kendall's tau metrics."""

    def _compute_variables(self):
        """Compute a tuple containing the concordant pairs, discordant pairs,
        ties in `y_true` and `y_pred`.

        Returns:
          A tuple `(p, q, t, u)` of float32 scalars: concordant pairs,
          discordant pairs, pairs tied only in `y_true` and pairs tied only
          in `y_pred`.
        """
        m = tf.cast(self.m, tf.float32)
        # n_cap[a, b] = number of samples in bins with row <= a and col <= b.
        n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1)
        # Number of concordant pairs.
        p = tf.math.reduce_sum(tf.multiply(n_cap[:-1, :-1], m[1:, 1:]))
        sum_m_squared = tf.math.reduce_sum(tf.math.square(m))
        # Ties in x.
        t = (
            tf.cast(tf.math.reduce_sum(tf.math.square(self.nrow)), tf.float32)
            - sum_m_squared
        ) / 2.0
        # Ties in y.
        u = (
            tf.cast(tf.math.reduce_sum(tf.math.square(self.ncol)), tf.float32)
            - sum_m_squared
        ) / 2.0
        # Ties in both.
        b = tf.math.reduce_sum(tf.multiply(m, (m - 1.0))) / 2.0
        # Number of discordant pairs: all pairs minus every other category.
        n = tf.cast(self.n, tf.float32)
        q = (n - 1.0) * n / 2.0 - p - t - u - b
        return p, q, t, u


@tf.keras.utils.register_keras_serializable(package="Addons")
class KendallsTauB(KendallsTauBase):
    """Computes Kendall's Tau-b Rank Correlation Coefficient.

    Usage:
    >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32)
    >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32)
    >>> m = tfa.metrics.KendallsTauB(0, 13, 0, 8)
    >>> m.update_state(actuals, preds)
    >>> m.result().numpy()
    -0.47140455
    """

    def result(self):
        """Returns `(p - q) / sqrt((p + q + t) * (p + q + u))`."""
        p, q, t, u = self._compute_variables()
        return (p - q) / tf.math.sqrt((p + q + t) * (p + q + u))


@tf.keras.utils.register_keras_serializable(package="Addons")
class KendallsTauC(KendallsTauBase):
    """Computes Kendall's Tau-c Rank Correlation Coefficient.

    Usage:
    >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32)
    >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32)
    >>> m = tfa.metrics.KendallsTauC(0, 13, 0, 8)
    >>> m.update_state(actuals, preds)
    >>> m.result().numpy()
    -0.48000002
    """

    def result(self):
        """Returns `2 * (p - q) / (n**2 * (m - 1) / m)`, where `m` is the
        smaller of the number of non-empty row bins and non-empty column bins.
        """
        p, q, _, _ = self._compute_variables()
        n = tf.cast(self.n, tf.float32)
        non_zeros_col = tf.math.count_nonzero(self.ncol)
        non_zeros_row = tf.math.count_nonzero(self.nrow)
        m = tf.cast(tf.minimum(non_zeros_col, non_zeros_row), tf.float32)
        return 2 * (p - q) / (tf.square(n) * (m - 1) / m)


@tf.keras.utils.register_keras_serializable(package="Addons")
class SpearmansRank(CorrelationBase):
    """Computes Spearman's Rank Correlation Coefficient.

    Usage:
    >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32)
    >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32)
    >>> m = tfa.metrics.SpearmansRank(0, 13, 0, 8)
    >>> m.update_state(actuals, preds)
    >>> m.result().numpy()
    -0.54073805
    """

    def result(self):
        """Returns the correlation of the per-bin rank scores, weighted by
        the joint histogram.
        """
        nrow = tf.cast(self.nrow, tf.float32)
        ncol = tf.cast(self.ncol, tf.float32)
        n = tf.cast(self.n, tf.float32)

        # Per-bin rank score: samples in lower bins plus (count - n) / 2.
        # Empty bins use a placeholder count of -1; they carry zero weight in
        # the normalisation below and their rows/columns of `m` are zero.
        nrow_ = tf.where(nrow > 0, nrow, -1.0)
        rrow = tf.cumsum(nrow, exclusive=True) + (nrow_ - n) / 2
        ncol_ = tf.where(ncol > 0, ncol, -1.0)
        rcol = tf.cumsum(ncol, exclusive=True) + (ncol_ - n) / 2

        # Scale each score vector to unit histogram-weighted norm.
        rrow = rrow / tf.math.sqrt(tf.reduce_sum(nrow * tf.square(rrow)))
        rcol = rcol / tf.math.sqrt(tf.reduce_sum(ncol * tf.square(rcol)))

        # corr = rrow^T . m . rcol
        m = tf.cast(self.m, tf.float32)
        corr = tf.matmul(tf.expand_dims(rrow, axis=0), m)
        corr = tf.matmul(corr, tf.expand_dims(rcol, axis=1))
        return tf.squeeze(corr)


@tf.keras.utils.register_keras_serializable(package="Addons")
class PearsonsCorrelation(CorrelationBase):
    """Computes Pearson's Correlation Coefficient.

    Usage:
    >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32)
    >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32)
    >>> m = tfa.metrics.PearsonsCorrelation(0, 13, 0, 8)
    >>> m.update_state(actuals, preds)
    >>> m.result().numpy()
    -0.5618297
    """

    def result(self):
        """Returns the correlation computed with every sample replaced by the
        midpoint of the bin it fell into.
        """
        ncol = tf.cast(self.ncol, tf.float32)
        nrow = tf.cast(self.nrow, tf.float32)
        m = tf.cast(self.m, tf.float32)
        n = tf.cast(self.n, tf.float32)

        # Bin midpoints: left edge plus half the bin width.
        col_bins = (
            self.preds_cuts[1:] - self.preds_cuts[:-1]
        ) / 2.0 + self.preds_cuts[:-1]
        row_bins = (
            self.actual_cuts[1:] - self.actual_cuts[:-1]
        ) / 2.0 + self.actual_cuts[:-1]

        n_col = tf.reduce_sum(ncol)
        n_row = tf.reduce_sum(nrow)
        col_mean = tf.reduce_sum(ncol * col_bins) / n_col
        row_mean = tf.reduce_sum(nrow * row_bins) / n_row

        # Unnormalised variances: sum(count * x^2) - N * mean^2.
        col_var = tf.reduce_sum(ncol * tf.square(col_bins)) - n_col * tf.square(
            col_mean
        )
        row_var = tf.reduce_sum(nrow * tf.square(row_bins)) - n_row * tf.square(
            row_mean
        )

        # Unnormalised covariance: sum(m * x * y) - n * mean_x * mean_y.
        joint_product = m * tf.expand_dims(row_bins, axis=1) * col_bins

        corr = tf.reduce_sum(joint_product) - n * col_mean * row_mean

        return corr / tf.sqrt(col_var * row_var)
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Kendall's Tau-b Metric.""" -import pytest -import numpy as np -import tensorflow as tf -from scipy import stats -from tensorflow_addons.metrics import KendallsTau -from tensorflow_addons.testing.serialization import check_metric_serialization - - -def test_config(): - kp_obj = KendallsTau(name="kendalls_tau") - assert kp_obj.name == "kendalls_tau" - assert kp_obj.dtype == tf.float32 - assert kp_obj.actual_min == 0.0 - assert kp_obj.actual_max == 1.0 - - # Check save and restore config - kp_obj2 = KendallsTau.from_config(kp_obj.get_config()) - assert kp_obj2.name == "kendalls_tau" - assert kp_obj2.dtype == tf.float32 - assert kp_obj2.actual_min == 0.0 - assert kp_obj2.actual_max == 1.0 - - -def test_scoring_with_ties(): - actuals = [12, 2, 1, 12, 2] - preds = [1, 4, 7, 1, 0] - actuals = tf.constant(actuals, dtype=tf.int32) - preds = tf.constant(preds, dtype=tf.int32) - - metric = KendallsTau(0, 13, 0, 8) - metric.update_state(actuals, preds) - np.testing.assert_almost_equal(metric.result(), stats.kendalltau(actuals, preds)[0]) - - -def test_perfect(): - actuals = [1, 2, 3, 4, 5, 6, 7, 8] - preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] - actuals = tf.constant(actuals, dtype=tf.int32) - preds = tf.constant(preds, dtype=tf.float32) - - metric = KendallsTau(0, 10, 0.0, 1.0) - metric.update_state(actuals, preds) - np.testing.assert_almost_equal(metric.result(), 1.0) - - -def test_reversed(): - actuals = [1, 2, 3, 4, 5, 6, 7, 8] - preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8][::-1] - actuals = tf.constant(actuals, dtype=tf.int32) - preds = tf.constant(preds, dtype=tf.float32) - - metric = KendallsTau(0, 10, 0.0, 1.0) - metric.update_state(actuals, preds) - np.testing.assert_almost_equal(metric.result(), -1.0) - - -def test_scoring_iterative(): - actuals = [12, 2, 1, 
12, 2] - preds = [1, 4, 7, 1, 0] - - metric = KendallsTau(0, 13, 0, 8) - for actual, pred in zip(actuals, preds): - metric.update_state(tf.constant([[actual]]), tf.constant([[pred]])) - np.testing.assert_almost_equal(metric.result(), stats.kendalltau(actuals, preds)[0]) - - -@pytest.mark.usefixtures("maybe_run_functions_eagerly") -def test_keras_binary_classification_model(): - kp = KendallsTau() - inputs = tf.keras.layers.Input(shape=(10,)) - outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs) - model = tf.keras.models.Model(inputs, outputs) - model.compile(optimizer="sgd", loss="binary_crossentropy", metrics=[kp]) - - x = np.random.rand(1000, 10).astype(np.float32) - y = np.random.rand(1000, 1).astype(np.float32) - - history = model.fit(x, y, epochs=1, verbose=0, batch_size=32) - assert not any(np.isnan(history.history["kendalls_tau"])) - - -def test_kendalls_tau_serialization(): - actuals = np.array([4, 4, 3, 3, 2, 2, 1, 1], dtype=np.int32) - preds = np.array([1, 2, 4, 1, 3, 3, 4, 4], dtype=np.int32) - - kt = KendallsTau(0, 5, 0, 5, 10, 10) - check_metric_serialization(kt, actuals, preds) diff --git a/tensorflow_addons/metrics/tests/streaming_correlations_test.py b/tensorflow_addons/metrics/tests/streaming_correlations_test.py new file mode 100644 index 0000000000..046f3421d6 --- /dev/null +++ b/tensorflow_addons/metrics/tests/streaming_correlations_test.py @@ -0,0 +1,136 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for streaming correlations metrics."""
import pytest
import numpy as np
import tensorflow as tf
from scipy import stats


from tensorflow_addons.metrics import KendallsTauB
from tensorflow_addons.metrics import KendallsTauC
from tensorflow_addons.metrics import PearsonsCorrelation
from tensorflow_addons.metrics import SpearmansRank
from tensorflow_addons.testing.serialization import check_metric_serialization


class TestStreamingCorrelations:
    """Checks every streaming correlation metric against its scipy reference."""

    # Metric class -> scipy callable returning a result whose first item is
    # the correlation statistic.
    scipy_corr = {
        KendallsTauB: lambda x, y: stats.kendalltau(x, y, variant="b"),
        KendallsTauC: lambda x, y: stats.kendalltau(x, y, variant="c"),
        SpearmansRank: stats.spearmanr,
        PearsonsCorrelation: stats.pearsonr,
    }

    testing_types = scipy_corr.keys()

    def _reference(self, correlation_type, x, y):
        """Returns the scipy statistic matching `correlation_type`."""
        reference_fn = self.scipy_corr[correlation_type]
        return reference_fn(x, y)[0]

    @pytest.mark.parametrize("correlation_type", testing_types)
    def test_config(self, correlation_type):
        expected_name = correlation_type.__name__
        obj = correlation_type(name=expected_name)
        assert obj.name == expected_name
        assert obj.dtype == tf.float32
        assert obj.actual_min == 0.0
        assert obj.actual_max == 1.0

        # Check save and restore config
        restored = correlation_type.from_config(obj.get_config())
        assert restored.name == expected_name
        assert restored.dtype == tf.float32
        assert restored.actual_min == 0.0
        assert restored.actual_max == 1.0

    @pytest.mark.parametrize("correlation_type", testing_types)
    def test_scoring_with_ties(self, correlation_type):
        actuals = [12, 2, 1, 12, 2]
        preds = [1, 4, 7, 1, 0]
        metric = correlation_type(0, 13, 0, 8)
        metric.update_state(actuals, preds)

        expected = self._reference(correlation_type, actuals, preds)
        np.testing.assert_almost_equal(metric.result(), expected, decimal=2)

    @pytest.mark.parametrize("correlation_type", testing_types)
    def test_perfect(self, correlation_type):
        actuals = [1, 2, 3, 4, 5, 6, 7, 8]
        preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
        metric = correlation_type(0, 10, 0.0, 1.0)
        metric.update_state(actuals, preds)

        expected = self._reference(correlation_type, actuals, preds)
        np.testing.assert_almost_equal(metric.result(), expected)

    @pytest.mark.parametrize("correlation_type", testing_types)
    def test_reversed(self, correlation_type):
        actuals = [1, 2, 3, 4, 5, 6, 7, 8]
        preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8][::-1]
        metric = correlation_type(0, 10, 0.0, 1.0)
        metric.update_state(actuals, preds)

        expected = self._reference(correlation_type, actuals, preds)
        np.testing.assert_almost_equal(metric.result(), expected)

    @pytest.mark.parametrize("correlation_type", testing_types)
    def test_scoring_streaming(self, correlation_type):
        actuals = [12, 2, 1, 12, 2]
        preds = [1, 4, 7, 1, 0]

        # Feed the samples one at a time; the result must match the batch one.
        metric = correlation_type(0, 13, 0, 8)
        for actual, pred in zip(actuals, preds):
            metric.update_state([[actual]], [[pred]])

        expected = self._reference(correlation_type, actuals, preds)
        np.testing.assert_almost_equal(metric.result(), expected, decimal=2)

    @pytest.mark.parametrize("correlation_type", testing_types)
    @pytest.mark.usefixtures("maybe_run_functions_eagerly")
    def test_keras_binary_classification_model(self, correlation_type):
        metric = correlation_type()
        inputs = tf.keras.layers.Input(shape=(128,))
        outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs)
        model = tf.keras.models.Model(inputs, outputs)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
            loss="binary_crossentropy",
            metrics=[metric],
        )

        x = np.random.rand(1024, 128).astype(np.float32)
        y = np.random.randint(2, size=(1024, 1)).astype(np.float32)

        initial_correlation = self._reference(
            correlation_type, model(x)[:, 0], y[:, 0]
        )

        history = model.fit(x, y, epochs=1, verbose=0, batch_size=32)

        # the training should increase the correlation metric
        assert np.all(history.history[metric.name] > initial_correlation)

        preds = model(x)
        metric.reset_state()
        # we decorate with tf.function to ensure the metric is also checked against graph mode.
        # keras automatically decorates the metrics compiled within keras.Model.
        graph_update = tf.function(metric.update_state)
        graph_result = tf.function(metric.result)
        graph_update(y, preds)
        metric_value = graph_result()
        expected = self._reference(correlation_type, preds[:, 0], y[:, 0])
        np.testing.assert_almost_equal(metric_value, expected, decimal=2)

    @pytest.mark.parametrize("correlation_type", testing_types)
    def test_serialization(self, correlation_type):
        actuals = np.array([4, 4, 3, 3, 2, 2, 1, 1], dtype=np.int32)
        preds = np.array([1, 2, 4, 1, 3, 3, 4, 4], dtype=np.int32)

        metric = correlation_type(0, 5, 0, 5, 10, 10)
        check_metric_serialization(metric, actuals, preds)