diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 7fd210c48e..6f723b740d 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -94,6 +94,8 @@
 /tensorflow_addons/layers/tests/noisy_dense_test.py @markub3327
 /tensorflow_addons/layers/max_unpooling_2d.py @thaink
 /tensorflow_addons/layers/tests/max_unpooling_2d_test.py @thaink
+/tensorflow_addons/layers/max_unpooling_2d_v2.py @midsterx
+/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py @midsterx
 /tensorflow_addons/layers/embedding_bag.py @rocketknight1
 /tensorflow_addons/layers/tests/embedding_bag_test.py @rocketknight1
 
diff --git a/tensorflow_addons/layers/__init__.py b/tensorflow_addons/layers/__init__.py
index 32072f826a..edf3788748 100644
--- a/tensorflow_addons/layers/__init__.py
+++ b/tensorflow_addons/layers/__init__.py
@@ -26,6 +26,7 @@
 from tensorflow_addons.layers.embedding_bag import EmbeddingBag
 from tensorflow_addons.layers.gelu import GELU
 from tensorflow_addons.layers.max_unpooling_2d import MaxUnpooling2D
+from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2
 from tensorflow_addons.layers.maxout import Maxout
 from tensorflow_addons.layers.multihead_attention import MultiHeadAttention
 from tensorflow_addons.layers.normalizations import FilterResponseNormalization
diff --git a/tensorflow_addons/layers/max_unpooling_2d_v2.py b/tensorflow_addons/layers/max_unpooling_2d_v2.py
new file mode 100644
index 0000000000..5d7fdd9288
--- /dev/null
+++ b/tensorflow_addons/layers/max_unpooling_2d_v2.py
@@ -0,0 +1,108 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MaxUnpooling2DV2 operation."""
+
+import tensorflow as tf
+
+from typeguard import typechecked
+from typing import Iterable
+
+from tensorflow_addons.utils.keras_utils import normalize_tuple
+
+
+def _max_unpooling_2d_v2(updates, mask, output_size):
+    """Unpool the outputs of a maximum pooling operation.
+
+    Args:
+      updates: 4D tensor holding the pooled values to scatter.
+      mask: Tensor with the same shape as `updates`. Each entry is the
+        flattened `(y * width + x) * channel + c` position of the matching
+        value inside one output image (no batch offset).
+      output_size: Sequence of 4 integers
+        `(batch_size, height, width, channel)` giving the result shape.
+
+    Returns:
+      A tensor of shape `output_size` holding the values of `updates` at
+      the positions named by `mask` and zeros everywhere else.
+    """
+    mask = tf.cast(mask, "int32")
+    input_shape = tf.shape(updates, out_type="int32")
+    # Prefer statically known dimensions; fall back to the dynamic shape.
+    input_shape = [updates.shape[i] or input_shape[i] for i in range(4)]
+    output_shape = output_size
+
+    # Calculates indices for batch, height, width and feature maps.
+    one_like_mask = tf.ones_like(mask, dtype="int32")
+    batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0)
+    batch_range = tf.reshape(
+        tf.range(output_shape[0], dtype="int32"), shape=batch_shape
+    )
+    b = one_like_mask * batch_range
+    y = mask // (output_shape[2] * output_shape[3])
+    x = (mask // output_shape[3]) % output_shape[2]
+    feature_range = tf.range(output_shape[3], dtype="int32")
+    f = one_like_mask * feature_range
+
+    # Transposes indices & reshape update values to one dimension.
+    updates_size = tf.size(updates)
+    indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
+    values = tf.reshape(updates, [updates_size])
+    return tf.scatter_nd(indices, values, output_shape)
+
+
+@tf.keras.utils.register_keras_serializable(package="Addons")
+class MaxUnpooling2DV2(tf.keras.layers.Layer):
+    """Unpool the outputs of a maximum pooling operation.
+
+    This differs from MaxUnpooling2D in that it uses output_size rather than strides and padding
+    to calculate the unpooled tensor. This is because MaxPoolingWithArgMax can map several input
+    sizes to the same output size, and specifying the output size avoids ambiguity in the
+    inversion process.
+
+    This layer currently does not support outputs of MaxPoolingWithArgMax in the following cases:
+    - include_batch_in_index equals true.
+    - The max pooling operation results in duplicate values in updates and mask.
+
+    Args:
+      output_size: A tuple/list of 4 integers specifying (batch_size, height, width, channel).
+        The targeted output size.
+    Call Args:
+      updates: A 4D tensor of shape `(batch_size, height, width, channel)`.
+        The pooling result from max pooling.
+      mask: A 4D tensor of shape `(batch_size, height, width, channel)`.
+        The indices of the maximal values.
+    Output shape:
+      4D tensor with the same shape as output_size.
+    """
+
+    @typechecked
+    def __init__(
+        self,
+        output_size: Iterable[int],
+        **kwargs,
+    ):
+        super(MaxUnpooling2DV2, self).__init__(**kwargs)
+
+        self.output_size = normalize_tuple(output_size, 4, "output_size")
+
+    def call(self, updates, mask):
+        """Scatter `updates` into a tensor of shape `output_size`."""
+        return _max_unpooling_2d_v2(updates, mask, output_size=self.output_size)
+
+    def get_config(self):
+        """Return the base layer config extended with `output_size`."""
+        config = super(MaxUnpooling2DV2, self).get_config()
+        config["output_size"] = self.output_size
+        return config
diff --git a/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py b/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py
new file mode 100644
index 0000000000..d2382e91cf
--- /dev/null
+++ b/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py
@@ -0,0 +1,134 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MaxUnpooling2DV2 layers."""
+
+import numpy as np
+import pytest
+import tensorflow as tf
+from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_simple():
+    valid_input = np.array([13, 4]).astype(np.float32)
+    valid_input = np.reshape(valid_input, (1, 1, 2, 1))
+    indices = np.array([1, 6]).astype(np.float32)
+    indices = np.reshape(indices, (1, 1, 2, 1))
+    output_shape = (1, 2, 4, 1)
+    expected_output = np.array([0, 13, 0, 0, 0, 0, 4, 0]).astype(np.float32)
+    expected_output = np.reshape(expected_output, output_shape)
+
+    output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy()
+    np.testing.assert_array_equal(expected_output, output)
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_complex():
+    valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
+    valid_input = np.reshape(valid_input, (1, 2, 2, 2))
+    indices = np.array([0, 3, 4, 7, 8, 11, 12, 15]).astype(np.float32)
+    indices = np.reshape(indices, (1, 2, 2, 2))
+    output_shape = (1, 4, 2, 2)
+    expected_output = np.array([1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, 6, 7, 0, 0, 8]).astype(
+        np.float32
+    )
+    expected_output = np.reshape(expected_output, output_shape)
+
+    output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy()
+    np.testing.assert_array_equal(expected_output, output)
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_batch():
+    valid_input = np.array(
+        # fmt: off
+        [
+            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
+            22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
+        ]
+        # fmt: on
+    ).astype(np.float32)
+    valid_input = np.reshape(valid_input, (2, 2, 4, 2))
+    indices = np.array(
+        # fmt: off
+        [
+            2, 23, 8, 9, 12, 15, 40, 43, 44, 47, 72, 75, 80, 79, 62, 65, 0, 1, 30, 7,
+            14, 35, 42, 21, 68, 69, 50, 51, 56, 5, 86, 63
+        ]
+        # fmt: on
+    ).astype(np.float32)
+    indices = np.reshape(indices, (2, 2, 4, 2))
+    output_shape = (2, 4, 12, 2)
+    expected_output = np.array(
+        # fmt: off
+        [
+            0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 2, 0,
+            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10, 0,
+            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 16, 0, 0, 0, 0, 0, 0, 11,
+            0, 0, 12, 0, 0, 0, 14, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+            17, 18, 0, 0, 0, 30, 0, 20, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 24, 0,
+            0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0,
+            0, 0, 0, 27, 28, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 25, 26,
+            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0,
+            0, 0
+        ]
+        # fmt: on
+    ).astype(np.float32)
+    expected_output = np.reshape(expected_output, output_shape)
+
+    output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy()
+    np.testing.assert_array_equal(expected_output, output)
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_with_pooling_simple():
+    valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
+    valid_input = np.reshape(valid_input, (1, 2, 4, 1))
+    updates, indices = tf.nn.max_pool_with_argmax(
+        valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME"
+    )
+    expected_output = np.array([0, 0, 0, 0, 0, 6, 0, 8]).astype(np.float32)
+    expected_output = np.reshape(expected_output, valid_input.shape)
+
+    output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy()
+    np.testing.assert_array_equal(expected_output, output)
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_with_pooling():
+    valid_input = np.array(
+        [1, 2, 4, 3, 8, 6, 7, 5, 9, 10, 12, 11, 13, 16, 15, 14]
+    ).astype(np.float32)
+    valid_input = np.reshape(valid_input, (1, 4, 4, 1))
+    updates, indices = tf.nn.max_pool_with_argmax(
+        valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME"
+    )
+    expected_output = np.array(
+        [0, 0, 0, 0, 8, 0, 7, 0, 0, 0, 0, 0, 0, 16, 15, 0]
+    ).astype(np.float32)
+    expected_output = np.reshape(expected_output, valid_input.shape)
+
+    output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy()
+    np.testing.assert_array_equal(expected_output, output)
+
+
+@pytest.mark.usefixtures("maybe_run_functions_eagerly")
+def test_symbolic_tensor_shape():
+    valid_input = tf.keras.layers.Input((None, None, 1))
+    updates, indices = tf.nn.max_pool_with_argmax(
+        valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME"
+    )
+    with pytest.raises(ValueError):
+        MaxUnpooling2DV2(valid_input.shape)(updates, indices)