From 8ed8d83e84ccbb02b44431a97793745e22556146 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Fri, 5 Aug 2022 17:28:33 +0200 Subject: [PATCH 1/4] renaming kendalls_tau to streaming_correlations --- tensorflow_addons/metrics/__init__.py | 2 +- .../metrics/{kendalls_tau.py => streaming_correlations.py} | 0 .../{kendalls_tau_test.py => streaming_correlations_test.py} | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename tensorflow_addons/metrics/{kendalls_tau.py => streaming_correlations.py} (100%) rename tensorflow_addons/metrics/tests/{kendalls_tau_test.py => streaming_correlations_test.py} (100%) diff --git a/tensorflow_addons/metrics/__init__.py b/tensorflow_addons/metrics/__init__.py index 8a2d6b7464..4de19a8b5c 100755 --- a/tensorflow_addons/metrics/__init__.py +++ b/tensorflow_addons/metrics/__init__.py @@ -31,4 +31,4 @@ from tensorflow_addons.metrics.r_square import RSquare from tensorflow_addons.metrics.geometric_mean import GeometricMean from tensorflow_addons.metrics.harmonic_mean import HarmonicMean -from tensorflow_addons.metrics.kendalls_tau import KendallsTau +from tensorflow_addons.metrics.streaming_correlations import KendallsTau diff --git a/tensorflow_addons/metrics/kendalls_tau.py b/tensorflow_addons/metrics/streaming_correlations.py similarity index 100% rename from tensorflow_addons/metrics/kendalls_tau.py rename to tensorflow_addons/metrics/streaming_correlations.py diff --git a/tensorflow_addons/metrics/tests/kendalls_tau_test.py b/tensorflow_addons/metrics/tests/streaming_correlations_test.py similarity index 100% rename from tensorflow_addons/metrics/tests/kendalls_tau_test.py rename to tensorflow_addons/metrics/tests/streaming_correlations_test.py From b4977f41b44263406c84f28effb64428a2224200 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Fri, 5 Aug 2022 15:44:51 +0200 Subject: [PATCH 2/4] Fix keras import --- .../metrics/streaming_correlations.py | 265 +++++++++++++----- .../tests/streaming_correlations_test.py | 196 +++++++------ 2 files changed, 315 insertions(+), 146 deletions(-) diff --git a/tensorflow_addons/metrics/streaming_correlations.py b/tensorflow_addons/metrics/streaming_correlations.py index 3f1a69f392..330205f189 100644 --- a/tensorflow_addons/metrics/streaming_correlations.py +++ b/tensorflow_addons/metrics/streaming_correlations.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,35 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Approximate Kendall's Tau-b Metric.""" +"""Approximate Pearson's, Spearman's, Kendall's Tau-b/c correlations based +on the algorithm of Wei Xiao https://arxiv.org/abs/1712.01521.""" +from abc import abstractmethod +import numpy as np import tensorflow as tf +from tensorflow.keras import backend from tensorflow.keras.metrics import Metric from tensorflow_addons.utils.types import AcceptableDTypes - from typeguard import typechecked -@tf.keras.utils.register_keras_serializable(package="Addons") -class KendallsTau(Metric): - """Computes Kendall's Tau-b Rank Correlation Coefficient. - - A measure of ordinal similarity between equal length sequences - of values, with allowances for ties. - - Based on the algorithm of Wei Xiao https://arxiv.org/abs/1712.01521. - - Usage: +class CorrelationBase(Metric): + """Base class for streaming correlation metrics. - ```python - actuals = tf.constant([12, 2, 1, 12, 2], dtype=np.int32) - preds = tf.constant([1, 4, 7, 1, 0], dtype=np.int32) + Based on https://arxiv.org/abs/1712.01521. - m = tfa.metrics.KendallsTau(0, 13) - m.update_state(actuals, preds) - print('Final result: ', m.result().numpy()) # Result: -0.4714045 - ``` + It stores and updates the joint and marginal histograms of (`y_true`, `y_pred`). + The concrete classes estimate the different correlation metrics + based on those histograms. """ @typechecked @@ -52,10 +44,10 @@ def __init__( preds_max: float = 1.0, actual_cutpoints: int = 100, preds_cutpoints: int = 100, - name: str = "kendalls_tau", + name: str = None, dtype: AcceptableDTypes = None, ): - """Creates a `KendallsTau` instance. + """Creates a `CorrelationBase` instance. Args: actual_min: the inclusive lower bound on values from actual. @@ -76,40 +68,50 @@ def __init__( self.preds_max = preds_max self.actual_cutpoints = actual_cutpoints self.preds_cutpoints = preds_cutpoints - self.actual_cuts = tf.linspace( + actual_cuts = np.linspace( tf.cast(self.actual_min, tf.float32), tf.cast(self.actual_max, tf.float32), - self.actual_cutpoints - 1, + self.actual_cutpoints, ) - self.preds_cuts = tf.linspace( + actual_cuts[-1] += backend.epsilon() + preds_cuts = np.linspace( tf.cast(self.preds_min, tf.float32), tf.cast(self.preds_max, tf.float32), - self.preds_cutpoints - 1, + self.preds_cutpoints, ) + preds_cuts[-1] += backend.epsilon() + self.actual_cuts = tf.convert_to_tensor(actual_cuts, tf.float32) + self.preds_cuts = tf.convert_to_tensor(preds_cuts, tf.float32) self.m = self.add_weight( - "m", (self.actual_cutpoints, self.preds_cutpoints), dtype=tf.int64 + "m", (self.actual_cutpoints - 1, self.preds_cutpoints - 1), dtype=tf.int64 ) - self.nrow = self.add_weight("nrow", (self.actual_cutpoints), dtype=tf.int64) - self.ncol = self.add_weight("ncol", (self.preds_cutpoints), dtype=tf.int64) + self.nrow = self.add_weight("nrow", (self.actual_cutpoints - 1), dtype=tf.int64) + self.ncol = self.add_weight("ncol", (self.preds_cutpoints - 1), dtype=tf.int64) self.n = self.add_weight("n", (), dtype=tf.int64) def update_state(self, y_true, y_pred, sample_weight=None): - """Accumulates ranks. + """Updates `m`, `nrow`, `ncol` respectively the joint and + marginal histograms of (`y_true`, `y_pred`) + """ - Args: - y_true: actual rank values. - y_pred: predicted rank values. - sample_weight (optional): Ignored. + y_true = tf.clip_by_value(y_true, self.actual_min, self.actual_max) + y_pred = tf.clip_by_value(y_pred, self.preds_min, self.preds_max) - Returns: - Update op. - """ - i = tf.searchsorted( - self.actual_cuts, - tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype), + i = ( + tf.searchsorted( + self.actual_cuts, + tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype), + side="right", + ) + - 1 ) - j = tf.searchsorted( - self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype) + j = ( + tf.searchsorted( + self.preds_cuts, + tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype), + side="right", + ) + - 1 ) m = tf.sparse.from_dense(self.m) @@ -149,7 +151,51 @@ def update_state(self, y_true, y_pred, sample_weight=None): self.nrow.assign(tf.sparse.to_dense(nrow)) self.ncol.assign(tf.sparse.to_dense(ncol)) + @abstractmethod def result(self): + pass + + def get_config(self): + """Returns the serializable config of the metric.""" + + config = { + "actual_min": self.actual_min, + "actual_max": self.actual_max, + "preds_min": self.preds_min, + "preds_max": self.preds_max, + "actual_cutpoints": self.actual_cutpoints, + "preds_cutpoints": self.preds_cutpoints, + } + base_config = super().get_config() + return {**base_config, **config} + + def reset_state(self): + """Resets all of the metric state variables.""" + + self.m.assign( + tf.zeros((self.actual_cutpoints - 1, self.preds_cutpoints - 1), tf.int64) + ) + self.nrow.assign(tf.zeros((self.actual_cutpoints - 1), tf.int64)) + self.ncol.assign(tf.zeros((self.preds_cutpoints - 1), tf.int64)) + self.n.assign(0) + + def reset_states(self): + # Backwards compatibility alias of `reset_state`. New classes should + # only implement `reset_state`. + # Required in Tensorflow < 2.5.0 + return self.reset_state() + + +class KendallsTauBase(CorrelationBase): + """Base class for kendall's tau metrics.""" + + def _compute_variables(self): + """Compute a tuple containing the concordant pairs, discordant pairs, + ties in `y_true` and `y_pred`. + + Returns: + A tuple + """ m = tf.cast(self.m, tf.float32) n_cap = tf.cumsum(tf.cumsum(m, axis=0), axis=1) # Number of concordant pairs. @@ -170,32 +216,121 @@ def result(self): # Number of discordant pairs. n = tf.cast(self.n, tf.float32) q = (n - 1.0) * n / 2.0 - p - t - u - b + return p, q, t, u + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class KendallsTauB(KendallsTauBase): + """Computes Kendall's Tau-b Rank Correlation Coefficient. + + Usage: + >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32) + >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32) + >>> m = tfa.metrics.KendallsTauB(0, 13, 0, 8) + >>> m.update_state(actuals, preds) + >>> m.result().numpy() + -0.47140455 + """ + + def result(self): + p, q, t, u = self._compute_variables() return (p - q) / tf.math.sqrt((p + q + t) * (p + q + u)) - def get_config(self): - """Returns the serializable config of the metric.""" - config = { - "actual_min": self.actual_min, - "actual_max": self.actual_max, - "preds_min": self.preds_min, - "preds_max": self.preds_max, - "actual_cutpoints": self.actual_cutpoints, - "preds_cutpoints": self.preds_cutpoints, - } - base_config = super().get_config() - return {**base_config, **config} +@tf.keras.utils.register_keras_serializable(package="Addons") +class KendallsTauC(KendallsTauBase): + """Computes Kendall's Tau-c Rank Correlation Coefficient. - def reset_state(self): - """Resets all of the metric state variables.""" + Usage: + >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32) + >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32) + >>> m = tfa.metrics.KendallsTauC(0, 13, 0, 8) + >>> m.update_state(actuals, preds) + >>> m.result().numpy() + -0.48000002 + """ - self.m.assign(tf.zeros((self.actual_cutpoints, self.preds_cutpoints), tf.int64)) - self.nrow.assign(tf.zeros((self.actual_cutpoints), tf.int64)) - self.ncol.assign(tf.zeros((self.preds_cutpoints), tf.int64)) - self.n.assign(0) + def result(self): + p, q, _, _ = self._compute_variables() + n = tf.cast(self.n, tf.float32) + non_zeros_col = tf.math.count_nonzero(self.ncol) + non_zeros_row = tf.math.count_nonzero(self.nrow) + m = tf.cast(tf.minimum(non_zeros_col, non_zeros_row), tf.float32) + return 2 * (p - q) / (tf.square(n) * (m - 1) / m) - def reset_states(self): - # Backwards compatibility alias of `reset_state`. New classes should - # only implement `reset_state`. - # Required in Tensorflow < 2.5.0 - return self.reset_state() + +@tf.keras.utils.register_keras_serializable(package="Addons") +class SpearmansRank(CorrelationBase): + """Computes Spearman's Rank Correlation Coefficient. + + Usage: + >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32) + >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32) + >>> m = tfa.metrics.SpearmansRank(0, 13, 0, 8) + >>> m.update_state(actuals, preds) + >>> m.result().numpy() + -0.54073805 + """ + + def result(self): + nrow = tf.cast(self.nrow, tf.float32) + ncol = tf.cast(self.ncol, tf.float32) + n = tf.cast(self.n, tf.float32) + + nrow_ = tf.where(nrow > 0, nrow, -1.0) + rrow = tf.pad(tf.cumsum(nrow)[:-1], [[1, 0]]) + (nrow_ - n) / 2 + ncol_ = tf.where(ncol > 0, ncol, -1.0) + rcol = tf.pad(tf.cumsum(ncol)[:-1], [[1, 0]]) + (ncol_ - n) / 2 + + rrow = rrow / tf.math.sqrt(tf.reduce_sum(nrow * tf.square(rrow))) + rcol = rcol / tf.math.sqrt(tf.reduce_sum(ncol * tf.square(rcol))) + + m = tf.cast(self.m, tf.float32) + corr = tf.matmul(tf.expand_dims(rrow, axis=0), m) + corr = tf.matmul(corr, tf.expand_dims(rcol, axis=1)) + return tf.squeeze(corr) + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class PearsonsCorrelation(CorrelationBase): + """Computes Pearsons's Correlation Coefficient. + + Usage: + >>> actuals = tf.constant([12, 2, 1, 12, 2], dtype=tf.int32) + >>> preds = tf.constant([1, 4, 7, 1, 0], dtype=tf.int32) + >>> m = tfa.metrics.PearsonsCorrelation(0, 13, 0, 8) + >>> m.update_state(actuals, preds) + >>> m.result().numpy() + -0.5618297 + """ + + def result(self): + ncol = tf.cast(self.ncol, tf.float32) + nrow = tf.cast(self.nrow, tf.float32) + m = tf.cast(self.m, tf.float32) + n = tf.cast(self.n, tf.float32) + + col_bins = (self.preds_cuts[1:] - self.preds_cuts[:-1]) / 2.0 + self.preds_cuts[ + :-1 + ] + row_bins = ( + self.actual_cuts[1:] - self.actual_cuts[:-1] + ) / 2.0 + self.actual_cuts[:-1] + + n_col = tf.reduce_sum(ncol) + n_row = tf.reduce_sum(nrow) + col_mean = tf.reduce_sum(ncol * col_bins) / n_col + row_mean = tf.reduce_sum(nrow * row_bins) / n_row + + col_var = tf.reduce_sum(ncol * tf.square(col_bins)) - n_col * tf.square( + col_mean + ) + row_var = tf.reduce_sum(nrow * tf.square(row_bins)) - n_row * tf.square( + row_mean + ) + + joint_product = m * tf.expand_dims(row_bins, axis=1) * col_bins + + corr = tf.reduce_sum(joint_product) - n * col_mean * row_mean + + return corr / tf.sqrt(col_var * row_var) diff --git a/tensorflow_addons/metrics/tests/streaming_correlations_test.py b/tensorflow_addons/metrics/tests/streaming_correlations_test.py index 4121c64b5e..f9eef283c6 100644 --- a/tensorflow_addons/metrics/tests/streaming_correlations_test.py +++ b/tensorflow_addons/metrics/tests/streaming_correlations_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,91 +12,125 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Kendall's Tau-b Metric.""" +"""Tests for streaming correlations metrics.""" import pytest import numpy as np import tensorflow as tf from scipy import stats -from tensorflow_addons.metrics import KendallsTau -from tensorflow_addons.testing.serialization import check_metric_serialization - - -def test_config(): - kp_obj = KendallsTau(name="kendalls_tau") - assert kp_obj.name == "kendalls_tau" - assert kp_obj.dtype == tf.float32 - assert kp_obj.actual_min == 0.0 - assert kp_obj.actual_max == 1.0 - - # Check save and restore config - kp_obj2 = KendallsTau.from_config(kp_obj.get_config()) - assert kp_obj2.name == "kendalls_tau" - assert kp_obj2.dtype == tf.float32 - assert kp_obj2.actual_min == 0.0 - assert kp_obj2.actual_max == 1.0 - - -def test_scoring_with_ties(): - actuals = [12, 2, 1, 12, 2] - preds = [1, 4, 7, 1, 0] - actuals = tf.constant(actuals, dtype=tf.int32) - preds = tf.constant(preds, dtype=tf.int32) - - metric = KendallsTau(0, 13, 0, 8) - metric.update_state(actuals, preds) - np.testing.assert_almost_equal(metric.result(), stats.kendalltau(actuals, preds)[0]) - - -def test_perfect(): - actuals = [1, 2, 3, 4, 5, 6, 7, 8] - preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] - actuals = tf.constant(actuals, dtype=tf.int32) - preds = tf.constant(preds, dtype=tf.float32) - - metric = KendallsTau(0, 10, 0.0, 1.0) - metric.update_state(actuals, preds) - np.testing.assert_almost_equal(metric.result(), 1.0) -def test_reversed(): - actuals = [1, 2, 3, 4, 5, 6, 7, 8] - preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8][::-1] - actuals = tf.constant(actuals, dtype=tf.int32) - preds = tf.constant(preds, dtype=tf.float32) - - metric = KendallsTau(0, 10, 0.0, 1.0) - metric.update_state(actuals, preds) - np.testing.assert_almost_equal(metric.result(), -1.0) - - -def test_scoring_iterative(): - actuals = [12, 2, 1, 12, 2] - preds = [1, 4, 7, 1, 0] - - metric = KendallsTau(0, 13, 0, 8) - for actual, pred in zip(actuals, preds): - metric.update_state(tf.constant([[actual]]), tf.constant([[pred]])) - np.testing.assert_almost_equal(metric.result(), stats.kendalltau(actuals, preds)[0]) - - -@pytest.mark.usefixtures("maybe_run_functions_eagerly") -def test_keras_binary_classification_model(): - kp = KendallsTau() - inputs = tf.keras.layers.Input(shape=(10,)) - outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs) - model = tf.keras.models.Model(inputs, outputs) - model.compile(optimizer="sgd", loss="binary_crossentropy", metrics=[kp]) - - x = np.random.rand(1000, 10).astype(np.float32) - y = np.random.rand(1000, 1).astype(np.float32) - - history = model.fit(x, y, epochs=1, verbose=0, batch_size=32) - assert not any(np.isnan(history.history["kendalls_tau"])) - +from tensorflow_addons.metrics import KendallsTauB +from tensorflow_addons.metrics import KendallsTauC +from tensorflow_addons.metrics import PearsonsCorrelation +from tensorflow_addons.metrics import SpearmansRank +from tensorflow_addons.testing.serialization import check_metric_serialization -def test_kendalls_tau_serialization(): - actuals = np.array([4, 4, 3, 3, 2, 2, 1, 1], dtype=np.int32) - preds = np.array([1, 2, 4, 1, 3, 3, 4, 4], dtype=np.int32) - kt = KendallsTau(0, 5, 0, 5, 10, 10) - check_metric_serialization(kt, actuals, preds) +class TestStreamingCorrelations: + scipy_corr = { + KendallsTauB: lambda x, y: stats.kendalltau(x, y, variant="b"), + KendallsTauC: lambda x, y: stats.kendalltau(x, y, variant="c"), + SpearmansRank: stats.spearmanr, + PearsonsCorrelation: stats.pearsonr, + } + + testing_types = scipy_corr.keys() + + @pytest.mark.parametrize("correlation_type", testing_types) + def test_config(self, correlation_type): + obj = correlation_type(name=correlation_type.__name__) + assert obj.name == correlation_type.__name__ + assert obj.dtype == tf.float32 + assert obj.actual_min == 0.0 + assert obj.actual_max == 1.0 + + # Check save and restore config + kp_obj2 = correlation_type.from_config(obj.get_config()) + assert kp_obj2.name == correlation_type.__name__ + assert kp_obj2.dtype == tf.float32 + assert kp_obj2.actual_min == 0.0 + assert kp_obj2.actual_max == 1.0 + + @pytest.mark.parametrize("correlation_type", testing_types) + def test_scoring_with_ties(self, correlation_type): + actuals = [12, 2, 1, 12, 2] + preds = [1, 4, 7, 1, 0] + metric = correlation_type(0, 13, 0, 8) + metric.update_state(actuals, preds) + + scipy_value = self.scipy_corr[correlation_type](actuals, preds)[0] + np.testing.assert_almost_equal(metric.result(), scipy_value, decimal=2) + + @pytest.mark.parametrize("correlation_type", testing_types) + def test_perfect(self, correlation_type): + actuals = [1, 2, 3, 4, 5, 6, 7, 8] + preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] + metric = correlation_type(0, 10, 0.0, 1.0) + metric.update_state(actuals, preds) + + scipy_value = self.scipy_corr[correlation_type](actuals, preds)[0] + np.testing.assert_almost_equal(metric.result(), scipy_value) + + @pytest.mark.parametrize("correlation_type", testing_types) + def test_reversed(self, correlation_type): + actuals = [1, 2, 3, 4, 5, 6, 7, 8] + preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8][::-1] + metric = correlation_type(0, 10, 0.0, 1.0) + metric.update_state(actuals, preds) + + scipy_value = self.scipy_corr[correlation_type](actuals, preds)[0] + np.testing.assert_almost_equal(metric.result(), scipy_value) + + @pytest.mark.parametrize("correlation_type", testing_types) + def test_scoring_streaming(self, correlation_type): + actuals = [12, 2, 1, 12, 2] + preds = [1, 4, 7, 1, 0] + + metric = correlation_type(0, 13, 0, 8) + for actual, pred in zip(actuals, preds): + metric.update_state([[actual]], [[pred]]) + + scipy_value = self.scipy_corr[correlation_type](actuals, preds)[0] + np.testing.assert_almost_equal(metric.result(), scipy_value, decimal=2) + + @pytest.mark.parametrize("correlation_type", testing_types) + @pytest.mark.usefixtures("maybe_run_functions_eagerly") + def test_keras_binary_classification_model(self, correlation_type): + metric = correlation_type() + inputs = tf.keras.layers.Input(shape=(128,)) + outputs = tf.keras.layers.Dense(1, activation="sigmoid")(inputs) + model = tf.keras.models.Model(inputs, outputs) + model.compile( + optimizer=tf.keras.optimizers.Adam(learning_rate=0.1), + loss="binary_crossentropy", + metrics=[metric], + ) + + x = np.random.rand(1024, 128).astype(np.float32) + y = np.random.randint(2, size=(1024, 1)).astype(np.float32) + + initial_correlation = self.scipy_corr[correlation_type]( + model(x)[:, 0], y[:, 0] + )[0] + + history = model.fit(x, y, epochs=1, verbose=0, batch_size=32) + + # the training should increase the correlation metric + assert np.all(history.history[metric.name] > initial_correlation) + + preds = model(x) + metric.reset_state() + # we decorate with tf.function to ensure the metric is also checked against graph mode. + # keras automatically decorates the metrics compiled within keras.Model. + tf.function(metric.update_state)(y, preds) + metric_value = tf.function(metric.result)() + scipy_value = self.scipy_corr[correlation_type](preds[:, 0], y[:, 0])[0] + np.testing.assert_almost_equal(metric_value, scipy_value, decimal=2) + + @pytest.mark.parametrize("correlation_type", testing_types) + def test_kendalls_tau_serialization(self, correlation_type): + actuals = np.array([4, 4, 3, 3, 2, 2, 1, 1], dtype=np.int32) + preds = np.array([1, 2, 4, 1, 3, 3, 4, 4], dtype=np.int32) + + kt = correlation_type(0, 5, 0, 5, 10, 10) + check_metric_serialization(kt, actuals, preds) From acc4bc8fd0618714bd85b4888415b1c8d80a8fb2 Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Fri, 5 Aug 2022 16:11:21 +0200 Subject: [PATCH 3/4] Fix test function naming --- tensorflow_addons/metrics/tests/streaming_correlations_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/metrics/tests/streaming_correlations_test.py b/tensorflow_addons/metrics/tests/streaming_correlations_test.py index f9eef283c6..046f3421d6 100644 --- a/tensorflow_addons/metrics/tests/streaming_correlations_test.py +++ b/tensorflow_addons/metrics/tests/streaming_correlations_test.py @@ -128,7 +128,7 @@ def test_keras_binary_classification_model(self, correlation_type): np.testing.assert_almost_equal(metric_value, scipy_value, decimal=2) @pytest.mark.parametrize("correlation_type", testing_types) - def test_kendalls_tau_serialization(self, correlation_type): + def test_serialization(self, correlation_type): actuals = np.array([4, 4, 3, 3, 2, 2, 1, 1], dtype=np.int32) preds = np.array([1, 2, 4, 1, 3, 3, 4, 4], dtype=np.int32) From aed94f3cc69587707d38d3263df7f1a2b5a0914c Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Fri, 5 Aug 2022 17:33:54 +0200 Subject: [PATCH 4/4] Fixing commit history to improve review experience --- tensorflow_addons/metrics/__init__.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow_addons/metrics/__init__.py b/tensorflow_addons/metrics/__init__.py index 4de19a8b5c..61a3b3f3bc 100755 --- a/tensorflow_addons/metrics/__init__.py +++ b/tensorflow_addons/metrics/__init__.py @@ -31,4 +31,9 @@ from tensorflow_addons.metrics.r_square import RSquare from tensorflow_addons.metrics.geometric_mean import GeometricMean from tensorflow_addons.metrics.harmonic_mean import HarmonicMean -from tensorflow_addons.metrics.streaming_correlations import KendallsTau +from tensorflow_addons.metrics.streaming_correlations import ( + KendallsTauB, + KendallsTauC, + PearsonsCorrelation, + SpearmansRank, +)