-
Notifications
You must be signed in to change notification settings - Fork 617
Add MaxUnpooling2DV2 layer #2594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """MaxUnpooling2DV2 operation.""" | ||
|
|
||
| import tensorflow as tf | ||
|
|
||
| from typeguard import typechecked | ||
| from typing import Iterable | ||
|
|
||
| from tensorflow_addons.utils.keras_utils import normalize_tuple | ||
|
|
||
|
|
||
| def _max_unpooling_2d_v2(updates, mask, output_size): | ||
| """Unpool the outputs of a maximum pooling operation.""" | ||
| mask = tf.cast(mask, "int32") | ||
| input_shape = tf.shape(updates, out_type="int32") | ||
| input_shape = [updates.shape[i] or input_shape[i] for i in range(4)] | ||
| output_shape = output_size | ||
|
|
||
| # Calculates indices for batch, height, width and feature maps. | ||
| one_like_mask = tf.ones_like(mask, dtype="int32") | ||
| batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0) | ||
| batch_range = tf.reshape( | ||
| tf.range(output_shape[0], dtype="int32"), shape=batch_shape | ||
| ) | ||
| b = one_like_mask * batch_range | ||
| y = mask // (output_shape[2] * output_shape[3]) | ||
| x = (mask // output_shape[3]) % output_shape[2] | ||
| feature_range = tf.range(output_shape[3], dtype="int32") | ||
| f = one_like_mask * feature_range | ||
|
|
||
| # Transposes indices & reshape update values to one dimension. | ||
| updates_size = tf.size(updates) | ||
| indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size])) | ||
| values = tf.reshape(updates, [updates_size]) | ||
| ret = tf.scatter_nd(indices, values, output_shape) | ||
| return ret | ||
|
|
||
|
|
||
| @tf.keras.utils.register_keras_serializable(package="Addons") | ||
| class MaxUnpooling2DV2(tf.keras.layers.Layer): | ||
| """Unpool the outputs of a maximum pooling operation. | ||
|
|
||
| This differs from MaxUnpooling2D in that it uses output_size rather than strides and padding | ||
| to calculate the unpooled tensor. This is because MaxPoolingWithArgMax can map several input | ||
| sizes to the same output size, and specifying the output size avoids ambiguity in the | ||
| inversion process. | ||
|
|
||
| This function currently does not support outputs of MaxPoolingWithArgMax in following cases: | ||
| - include_batch_in_index equals true. | ||
| - The max pooling operation results in duplicate values in updates and mask. | ||
|
|
||
| Args: | ||
| output_size: A tuple/list of 4 integers specifying (batch_size, height, width, channel). | ||
| The targeted output size. | ||
| Call Args: | ||
| updates: A 4D tensor of shape `(batch_size, height, width, channel)`. | ||
| The pooling result from max pooling. | ||
| mask: A 4D tensor of shape `(batch_size, height, width, channel)`. | ||
| The indices of the maximal values. | ||
| Output shape: | ||
| 4D tensor with the same shape as output_size. | ||
| """ | ||
|
|
||
| @typechecked | ||
| def __init__( | ||
| self, | ||
| output_size: Iterable[int], | ||
| **kwargs, | ||
| ): | ||
| super(MaxUnpooling2DV2, self).__init__(**kwargs) | ||
|
|
||
| self.output_size = normalize_tuple(output_size, 4, "output_size") | ||
|
|
||
| def call(self, updates, mask): | ||
| return _max_unpooling_2d_v2(updates, mask, output_size=self.output_size) | ||
|
|
||
| def get_config(self): | ||
| config = super(MaxUnpooling2DV2, self).get_config() | ||
| config["output_size"] = self.output_size | ||
| return config |
134 changes: 134 additions & 0 deletions
134
tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| # Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """Tests for MaxUnpooling2DV2 layers.""" | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import tensorflow as tf | ||
| from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2 | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
| def test_simple(): | ||
| valid_input = np.array([13, 4]).astype(np.float32) | ||
| valid_input = np.reshape(valid_input, (1, 1, 2, 1)) | ||
| indices = np.array([1, 6]).astype(np.float32) | ||
| indices = np.reshape(indices, (1, 1, 2, 1)) | ||
| output_shape = (1, 2, 4, 1) | ||
| expected_output = np.array([0, 13, 0, 0, 0, 0, 4, 0]).astype(np.float32) | ||
| expected_output = np.reshape(expected_output, output_shape) | ||
|
|
||
| output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy() | ||
| np.testing.assert_array_equal(expected_output, output) | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
| def test_complex(): | ||
| valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32) | ||
| valid_input = np.reshape(valid_input, (1, 2, 2, 2)) | ||
| indices = np.array([0, 3, 4, 7, 8, 11, 12, 15]).astype(np.float32) | ||
| indices = np.reshape(indices, (1, 2, 2, 2)) | ||
| output_shape = (1, 4, 2, 2) | ||
| expected_output = np.array([1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, 6, 7, 0, 0, 8]).astype( | ||
| np.float32 | ||
| ) | ||
| expected_output = np.reshape(expected_output, output_shape) | ||
|
|
||
| output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy() | ||
| np.testing.assert_array_equal(expected_output, output) | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
| def test_batch(): | ||
| valid_input = np.array( | ||
| # fmt: off | ||
| [ | ||
| 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, | ||
| 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 | ||
| ] | ||
| # fmt: on | ||
| ).astype(np.float32) | ||
| valid_input = np.reshape(valid_input, (2, 2, 4, 2)) | ||
| indices = np.array( | ||
| # fmt: off | ||
| [ | ||
| 2, 23, 8, 9, 12, 15, 40, 43, 44, 47, 72, 75, 80, 79, 62, 65, 0, 1, 30, 7, | ||
| 14, 35, 42, 21, 68, 69, 50, 51, 56, 5, 86, 63 | ||
| ] | ||
| # fmt: on | ||
| ).astype(np.float32) | ||
| indices = np.reshape(indices, (2, 2, 4, 2)) | ||
| output_shape = (2, 4, 12, 2) | ||
| expected_output = np.array( | ||
| # fmt: off | ||
| [ | ||
| 0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 2, 0, | ||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10, 0, | ||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 16, 0, 0, 0, 0, 0, 0, 11, | ||
| 0, 0, 12, 0, 0, 0, 14, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | ||
| 17, 18, 0, 0, 0, 30, 0, 20, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 24, 0, | ||
| 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0, | ||
| 0, 0, 0, 27, 28, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 25, 26, | ||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0, | ||
| 0, 0 | ||
| ] | ||
| # fmt: on | ||
| ).astype(np.float32) | ||
| expected_output = np.reshape(expected_output, output_shape) | ||
|
|
||
| output = MaxUnpooling2DV2(output_shape)(valid_input, indices) | ||
| np.testing.assert_array_equal(expected_output, output) | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
| def test_with_pooling_simple(): | ||
| valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32) | ||
| valid_input = np.reshape(valid_input, (1, 2, 4, 1)) | ||
| updates, indices = tf.nn.max_pool_with_argmax( | ||
| valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME" | ||
| ) | ||
| expected_output = np.array([0, 0, 0, 0, 0, 6, 0, 8]).astype(np.float32) | ||
| expected_output = np.reshape(expected_output, valid_input.shape) | ||
|
|
||
| output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy() | ||
| np.testing.assert_array_equal(expected_output, output) | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
| def test_with_pooling(): | ||
| valid_input = np.array( | ||
| [1, 2, 4, 3, 8, 6, 7, 5, 9, 10, 12, 11, 13, 16, 15, 14] | ||
| ).astype(np.float32) | ||
| valid_input = np.reshape(valid_input, (1, 4, 4, 1)) | ||
| updates, indices = tf.nn.max_pool_with_argmax( | ||
| valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME" | ||
| ) | ||
| expected_output = np.array( | ||
| [0, 0, 0, 0, 8, 0, 7, 0, 0, 0, 0, 0, 0, 16, 15, 0] | ||
| ).astype(np.float32) | ||
| expected_output = np.reshape(expected_output, valid_input.shape) | ||
|
|
||
| output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy() | ||
| np.testing.assert_array_equal(expected_output, output) | ||
|
|
||
|
|
||
| @pytest.mark.usefixtures("maybe_run_functions_eagerly") | ||
| def test_symbolic_tensor_shape(): | ||
| valid_input = tf.keras.layers.Input((None, None, 1)) | ||
| updates, indices = tf.nn.max_pool_with_argmax( | ||
| valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME" | ||
| ) | ||
| with pytest.raises(ValueError): | ||
| MaxUnpooling2DV2(valid_input.shape)(updates, indices) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a tflite test case too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you provide some clarity regarding this?
Did you mean something like
@pytest.mark.with_device(["cpu", "gpu", tf.distribute.MirroredStrategy])?