From d2c66fd62c7e1a566cfba796d42a26c56556fc2c Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 15 Aug 2022 20:43:42 +0200 Subject: [PATCH 1/2] remove unnecessary loop --- .../metrics/streaming_correlations.py | 47 +++++-------------- 1 file changed, 12 insertions(+), 35 deletions(-) diff --git a/tensorflow_addons/metrics/streaming_correlations.py b/tensorflow_addons/metrics/streaming_correlations.py index baee4a56c8..d2309ccc2e 100644 --- a/tensorflow_addons/metrics/streaming_correlations.py +++ b/tensorflow_addons/metrics/streaming_correlations.py @@ -114,42 +114,19 @@ def update_state(self, y_true, y_pred, sample_weight=None): - 1 ) - m = tf.sparse.from_dense(self.m) - nrow = tf.sparse.from_dense(self.nrow) - ncol = tf.sparse.from_dense(self.ncol) - - k = 0 - while k < tf.shape(i)[0]: - m = tf.sparse.add( - m, - tf.SparseTensor( - [[i[k], j[k]]], - tf.cast([1], dtype=m.dtype), - self.m.shape, - ), - ) - nrow = tf.sparse.add( - nrow, - tf.SparseTensor( - [[i[k]]], - tf.cast([1], dtype=nrow.dtype), - self.nrow.shape, - ), - ) - ncol = tf.sparse.add( - ncol, - tf.SparseTensor( - [[j[k]]], - tf.cast([1], dtype=ncol.dtype), - self.ncol.shape, - ), - ) - k += 1 + nrow = tf.tensor_scatter_nd_add( + self.nrow, tf.expand_dims(i, axis=-1), tf.ones_like(i) + ) + ncol = tf.tensor_scatter_nd_add( + self.ncol, tf.expand_dims(j, axis=-1), tf.ones_like(j) + ) + ij = tf.stack([i, j], axis=1) + m = tf.tensor_scatter_nd_add(self.m, ij, tf.ones_like(i)) - self.n.assign_add(tf.cast(k, tf.int64)) - self.m.assign(tf.sparse.to_dense(m)) - self.nrow.assign(tf.sparse.to_dense(nrow)) - self.ncol.assign(tf.sparse.to_dense(ncol)) + self.n.assign_add(tf.shape(i, out_type=tf.int64)[0]) + self.m.assign(m) + self.nrow.assign(nrow) + self.ncol.assign(ncol) @abstractmethod def result(self): From a8cce733ab4156a8dc692942bac433f25399dade Mon Sep 17 00:00:00 2001 From: Nicolas Pinchaud Date: Mon, 15 Aug 2022 21:46:11 +0200 Subject: [PATCH 2/2] tensor type fix --- tensorflow_addons/metrics/streaming_correlations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow_addons/metrics/streaming_correlations.py b/tensorflow_addons/metrics/streaming_correlations.py index d2309ccc2e..c6b8f975af 100644 --- a/tensorflow_addons/metrics/streaming_correlations.py +++ b/tensorflow_addons/metrics/streaming_correlations.py @@ -102,6 +102,7 @@ def update_state(self, y_true, y_pred, sample_weight=None): self.actual_cuts, tf.cast(tf.reshape(y_true, [-1]), self.actual_cuts.dtype), side="right", + out_type=tf.int64, ) - 1 ) @@ -110,6 +111,7 @@ def update_state(self, y_true, y_pred, sample_weight=None): self.preds_cuts, tf.cast(tf.reshape(y_pred, [-1]), self.preds_cuts.dtype), side="right", + out_type=tf.int64, ) - 1 )