From 5f9015cc8c65f8f452c4fc365f6d68c95c0f6616 Mon Sep 17 00:00:00 2001 From: midsterx Date: Tue, 2 Nov 2021 01:33:13 +0530 Subject: [PATCH 1/4] Add MaxUnpooling2DV2 layer --- tensorflow_addons/layers/__init__.py | 1 + .../layers/max_unpooling_2d_v2.py | 97 +++++++++++++ .../layers/tests/max_unpooling_2d_v2_test.py | 134 ++++++++++++++++++ tools/testing/source_code_test.py | 1 + 4 files changed, 233 insertions(+) create mode 100644 tensorflow_addons/layers/max_unpooling_2d_v2.py create mode 100644 tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py index 32072f826a..edf3788748 100644 --- a/tensorflow_addons/layers/__init__.py +++ b/tensorflow_addons/layers/__init__.py @@ -26,6 +26,7 @@ from tensorflow_addons.layers.embedding_bag import EmbeddingBag from tensorflow_addons.layers.gelu import GELU from tensorflow_addons.layers.max_unpooling_2d import MaxUnpooling2D +from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2 from tensorflow_addons.layers.maxout import Maxout from tensorflow_addons.layers.multihead_attention import MultiHeadAttention from tensorflow_addons.layers.normalizations import FilterResponseNormalization diff --git a/tensorflow_addons/layers/max_unpooling_2d_v2.py b/tensorflow_addons/layers/max_unpooling_2d_v2.py new file mode 100644 index 0000000000..3908762209 --- /dev/null +++ b/tensorflow_addons/layers/max_unpooling_2d_v2.py @@ -0,0 +1,97 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MaxUnpooling2DV2 operation.""" + +import tensorflow as tf + +from typeguard import typechecked +from typing import Iterable + +from tensorflow_addons.utils.keras_utils import normalize_tuple + + +def _max_unpooling_2d_v2(updates, mask, output_size): + """Unpool the outputs of a maximum pooling operation.""" + output_size_attr = " ".join(["i: %d" % v for v in output_size]) + experimental_implements = [ + 'name: "addons:MaxUnpooling2DV2"', + 'attr { key: "output_size" value { list {%s} } }' % output_size_attr, + ] + experimental_implements = " ".join(experimental_implements) + + @tf.function(experimental_implements=experimental_implements) + def func(updates, mask): + mask = tf.cast(mask, "int32") + input_shape = tf.shape(updates, out_type="int32") + input_shape = [updates.shape[i] or input_shape[i] for i in range(4)] + output_shape = output_size + + # Calculates indices for batch, height, width and feature maps. + one_like_mask = tf.ones_like(mask, dtype="int32") + batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0) + batch_range = tf.reshape( + tf.range(output_shape[0], dtype="int32"), shape=batch_shape + ) + b = one_like_mask * batch_range + y = mask // (output_shape[2] * output_shape[3]) + x = (mask // output_shape[3]) % output_shape[2] + feature_range = tf.range(output_shape[3], dtype="int32") + f = one_like_mask * feature_range + + # Transposes indices & reshape update values to one dimension. + updates_size = tf.size(updates) + indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size])) + values = tf.reshape(updates, [updates_size]) + ret = tf.scatter_nd(indices, values, output_shape) + return ret + + return func(updates, mask) + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class MaxUnpooling2DV2(tf.keras.layers.Layer): + """Unpool the outputs of a maximum pooling operation. + + This function currently does not support outputs of MaxPoolingWithArgMax in following cases: + - include_batch_in_index equals true. + - The max pooling operation results in duplicate values in updates and mask. + + Args: + updates: The pooling result from max pooling. + mask: the argmax result corresponds to above max values. + output_size: The targeted output size. + Input shape: + 4D tensor with shape: `(batch_size, height, width, channel)`. + Output shape: + 4D tensor with the same shape as `output_size`. + """ + + @typechecked + def __init__( + self, + output_size: Iterable[int], + **kwargs, + ): + super(MaxUnpooling2DV2, self).__init__(**kwargs) + + self.output_size = normalize_tuple(output_size, 4, "output_size") + + def call(self, updates, mask): + return _max_unpooling_2d_v2(updates, mask, output_size=self.output_size) + + def get_config(self): + config = super(MaxUnpooling2DV2, self).get_config() + config["output_size"] = self.output_size + return config diff --git a/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py b/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py new file mode 100644 index 0000000000..1b8bf82f82 --- /dev/null +++ b/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py @@ -0,0 +1,134 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MaxUnpooling2DV2 layers.""" + +import numpy as np +import pytest +import tensorflow as tf +from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2 + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_simple(): + valid_input = np.array([13, 4]).astype(np.float32) + valid_input = np.reshape(valid_input, (1, 1, 2, 1)) + indices = np.array([1, 6]).astype(np.float32) + indices = np.reshape(indices, (1, 1, 2, 1)) + output_shape = (1, 2, 4, 1) + expected_output = np.array([0, 13, 0, 0, 0, 0, 4, 0]).astype(np.float32) + expected_output = np.reshape(expected_output, output_shape) + + output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy() + np.testing.assert_array_equal(expected_output, output) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_complex(): + valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32) + valid_input = np.reshape(valid_input, (1, 2, 2, 2)) + indices = np.array([0, 3, 4, 7, 8, 11, 12, 15]).astype(np.float32) + indices = np.reshape(indices, (1, 2, 2, 2)) + output_shape = (1, 4, 2, 2) + expected_output = np.array([1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, 6, 7, 0, 0, 8]).astype( + np.float32 + ) + expected_output = np.reshape(expected_output, output_shape) + + output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy() + np.testing.assert_array_equal(expected_output, output) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_batch(): + valid_input = np.array( + # fmt: off + [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 + ] + # fmt: on + ).astype(np.float32) + valid_input = np.reshape(valid_input, (2, 2, 4, 2)) + indices = np.array( + # fmt: off + [ + 2, 23, 8, 9, 12, 15, 40, 43, 44, 47, 72, 75, 80, 79, 62, 65, 0, 1, 30, 7, + 14, 35, 42, 21, 68, 69, 50, 51, 56, 5, 86, 63 + ] + # fmt: on + ).astype(np.float32) + indices = np.reshape(indices, (2, 2, 4, 2)) + output_shape = (2, 4, 12, 2) + expected_output = np.array( + # fmt: off + [ + 0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 2, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 16, 0, 0, 0, 0, 0, 0, 11, + 0, 0, 12, 0, 0, 0, 14, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 17, 18, 0, 0, 0, 30, 0, 20, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 24, 0, + 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, + 0, 0, 0, 27, 28, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 25, 26, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, + 0, 0 + ] + # fmt: on + ).astype(np.float32) + expected_output = np.reshape(expected_output, output_shape) + + output = MaxUnpooling2DV2(output_shape)(valid_input, indices) + np.testing.assert_array_equal(expected_output, output) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_with_pooling_simple(): + valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32) + valid_input = np.reshape(valid_input, (1, 2, 4, 1)) + updates, indices = tf.nn.max_pool_with_argmax( + valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME" + ) + expected_output = np.array([0, 0, 0, 0, 0, 6, 0, 8]).astype(np.float32) + expected_output = np.reshape(expected_output, valid_input.shape) + + output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy() + np.testing.assert_array_equal(expected_output, output) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_with_pooling(): + valid_input = np.array( + [1, 2, 4, 3, 8, 6, 7, 5, 9, 10, 12, 11, 13, 16, 15, 14] + ).astype(np.float32) + valid_input = np.reshape(valid_input, (1, 4, 4, 1)) + updates, indices = tf.nn.max_pool_with_argmax( + valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME" + ) + expected_output = np.array( + [0, 0, 0, 0, 8, 0, 7, 0, 0, 0, 0, 0, 0, 16, 15, 0] + ).astype(np.float32) + expected_output = np.reshape(expected_output, valid_input.shape) + + output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy() + np.testing.assert_array_equal(expected_output, output) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_symbolic_tensor_shape(): + valid_input = tf.keras.layers.Input((None, None, 1)) + updates, indices = tf.nn.max_pool_with_argmax( + valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME" + ) + with pytest.raises(ValueError): + MaxUnpooling2DV2(valid_input.shape)(updates, indices) diff --git a/tools/testing/source_code_test.py b/tools/testing/source_code_test.py index c54bf73ea2..e3d22c23a2 100644 --- a/tools/testing/source_code_test.py +++ b/tools/testing/source_code_test.py @@ -149,6 +149,7 @@ def test_no_experimental_api(): allowlist = [ "tensorflow_addons/optimizers/weight_decay_optimizers.py", "tensorflow_addons/layers/max_unpooling_2d.py", + "tensorflow_addons/layers/max_unpooling_2d_v2.py", "tensorflow_addons/image/dense_image_warp.py", ] for file_path, line_idx, line in get_lines_of_source_code(allowlist): From 0a475ec2679246b38aa33441d63913e8bf156e99 Mon Sep 17 00:00:00 2001 From: midsterx Date: Mon, 15 Nov 2021 12:33:34 +0530 Subject: [PATCH 2/4] Resolve comments --- .../layers/max_unpooling_2d_v2.py | 59 ++++++++----------- .../layers/tests/max_unpooling_2d_v2_test.py | 2 +- tools/testing/source_code_test.py | 1 - 3 files changed, 25 insertions(+), 37 deletions(-) diff --git a/tensorflow_addons/layers/max_unpooling_2d_v2.py b/tensorflow_addons/layers/max_unpooling_2d_v2.py index 3908762209..eabcc927df 100644 --- a/tensorflow_addons/layers/max_unpooling_2d_v2.py +++ b/tensorflow_addons/layers/max_unpooling_2d_v2.py @@ -1,4 +1,4 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,40 +24,29 @@ def _max_unpooling_2d_v2(updates, mask, output_size): """Unpool the outputs of a maximum pooling operation.""" - output_size_attr = " ".join(["i: %d" % v for v in output_size]) - experimental_implements = [ - 'name: "addons:MaxUnpooling2DV2"', - 'attr { key: "output_size" value { list {%s} } }' % output_size_attr, - ] - experimental_implements = " ".join(experimental_implements) - - @tf.function(experimental_implements=experimental_implements) - def func(updates, mask): - mask = tf.cast(mask, "int32") - input_shape = tf.shape(updates, out_type="int32") - input_shape = [updates.shape[i] or input_shape[i] for i in range(4)] - output_shape = output_size - - # Calculates indices for batch, height, width and feature maps. - one_like_mask = tf.ones_like(mask, dtype="int32") - batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0) - batch_range = tf.reshape( - tf.range(output_shape[0], dtype="int32"), shape=batch_shape - ) - b = one_like_mask * batch_range - y = mask // (output_shape[2] * output_shape[3]) - x = (mask // output_shape[3]) % output_shape[2] - feature_range = tf.range(output_shape[3], dtype="int32") - f = one_like_mask * feature_range - - # Transposes indices & reshape update values to one dimension. - updates_size = tf.size(updates) - indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size])) - values = tf.reshape(updates, [updates_size]) - ret = tf.scatter_nd(indices, values, output_shape) - return ret - - return func(updates, mask) + mask = tf.cast(mask, "int32") + input_shape = tf.shape(updates, out_type="int32") + input_shape = [updates.shape[i] or input_shape[i] for i in range(4)] + output_shape = output_size + + # Calculates indices for batch, height, width and feature maps. + one_like_mask = tf.ones_like(mask, dtype="int32") + batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0) + batch_range = tf.reshape( + tf.range(output_shape[0], dtype="int32"), shape=batch_shape + ) + b = one_like_mask * batch_range + y = mask // (output_shape[2] * output_shape[3]) + x = (mask // output_shape[3]) % output_shape[2] + feature_range = tf.range(output_shape[3], dtype="int32") + f = one_like_mask * feature_range + + # Transposes indices & reshape update values to one dimension. + updates_size = tf.size(updates) + indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size])) + values = tf.reshape(updates, [updates_size]) + ret = tf.scatter_nd(indices, values, output_shape) + return ret @tf.keras.utils.register_keras_serializable(package="Addons") diff --git a/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py b/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py index 1b8bf82f82..d2382e91cf 100644 --- a/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py +++ b/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py @@ -1,4 +1,4 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tools/testing/source_code_test.py b/tools/testing/source_code_test.py index e3d22c23a2..c54bf73ea2 100644 --- a/tools/testing/source_code_test.py +++ b/tools/testing/source_code_test.py @@ -149,7 +149,6 @@ def test_no_experimental_api(): allowlist = [ "tensorflow_addons/optimizers/weight_decay_optimizers.py", "tensorflow_addons/layers/max_unpooling_2d.py", - "tensorflow_addons/layers/max_unpooling_2d_v2.py", "tensorflow_addons/image/dense_image_warp.py", ] for file_path, line_idx, line in get_lines_of_source_code(allowlist): From ac4981efbb2f2b0ed480dc13ca3c905d3109bead Mon Sep 17 00:00:00 2001 From: midsterx Date: Tue, 16 Nov 2021 01:41:10 +0530 Subject: [PATCH 3/4] Improve docstring and add to owners --- .github/CODEOWNERS | 2 ++ .../layers/max_unpooling_2d_v2.py | 19 +++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 7fd210c48e..6f723b740d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -94,6 +94,8 @@ /tensorflow_addons/layers/tests/noisy_dense_test.py @markub3327 /tensorflow_addons/layers/max_unpooling_2d.py @thaink /tensorflow_addons/layers/tests/max_unpooling_2d_test.py @thaink +/tensorflow_addons/layers/max_unpooling_2d_v2.py @midsterx +/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py @midsterx /tensorflow_addons/layers/embedding_bag.py @rocketknight1 /tensorflow_addons/layers/tests/embedding_bag_test.py @rocketknight1 diff --git a/tensorflow_addons/layers/max_unpooling_2d_v2.py b/tensorflow_addons/layers/max_unpooling_2d_v2.py index eabcc927df..7ba7ca102b 100644 --- a/tensorflow_addons/layers/max_unpooling_2d_v2.py +++ b/tensorflow_addons/layers/max_unpooling_2d_v2.py @@ -53,18 +53,25 @@ def _max_unpooling_2d_v2(updates, mask, output_size): class MaxUnpooling2DV2(tf.keras.layers.Layer): """Unpool the outputs of a maximum pooling operation. + This differs from MaxUnpooling2D in that it uses output_size rather than strides and padding + to calculate the unpooled tensor. This is because MaxPoolingWithArgMax can map several input + sizes to the same output size, and specifying the output size avoids ambiguity in the + inversion process. + This function currently does not support outputs of MaxPoolingWithArgMax in following cases: - include_batch_in_index equals true. - The max pooling operation results in duplicate values in updates and mask. Args: - updates: The pooling result from max pooling. - mask: the argmax result corresponds to above max values. - output_size: The targeted output size. - Input shape: - 4D tensor with shape: `(batch_size, height, width, channel)`. + output_size: A tuple/list of 4 integers specifying (batch_size, height, width, channel). + The targeted output size. + Call Args: + updates: A 4D tensor of shape `(batch_size, height, width, channel)`. + The pooling result from max pooling. + mask: A 4D tensor of shape `(batch_size, height, width, channel)`. + The indices of the maximal values. Output shape: - 4D tensor with the same shape as `output_size`. + 4D tensor with the same shape as output_size. """ @typechecked From c8a22be1ec1d0df3a9ff43bf5b97237f0d843d5f Mon Sep 17 00:00:00 2001 From: midsterx Date: Tue, 16 Nov 2021 02:04:29 +0530 Subject: [PATCH 4/4] Fix spacing in docstring --- tensorflow_addons/layers/max_unpooling_2d_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/layers/max_unpooling_2d_v2.py b/tensorflow_addons/layers/max_unpooling_2d_v2.py index 7ba7ca102b..5d7fdd9288 100644 --- a/tensorflow_addons/layers/max_unpooling_2d_v2.py +++ b/tensorflow_addons/layers/max_unpooling_2d_v2.py @@ -71,7 +71,7 @@ class MaxUnpooling2DV2(tf.keras.layers.Layer): mask: A 4D tensor of shape `(batch_size, height, width, channel)`. The indices of the maximal values. Output shape: - 4D tensor with the same shape as output_size. + 4D tensor with the same shape as output_size. """ @typechecked