Skip to content

Commit 3966935

Browse files
authored
Add MaxUnpooling2DV2 layer (#2594)
* Add MaxUnpooling2DV2 layer * Resolve comments * Improve docstring and add to owners * Fix spacing in docstring
1 parent 11a86b4 commit 3966935

File tree

4 files changed

+230
-0
lines changed

4 files changed

+230
-0
lines changed

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@
9494
/tensorflow_addons/layers/tests/noisy_dense_test.py @markub3327
9595
/tensorflow_addons/layers/max_unpooling_2d.py @thaink
9696
/tensorflow_addons/layers/tests/max_unpooling_2d_test.py @thaink
97+
/tensorflow_addons/layers/max_unpooling_2d_v2.py @midsterx
98+
/tensorflow_addons/layers/tests/max_unpooling_2d_v2_test.py @midsterx
9799
/tensorflow_addons/layers/embedding_bag.py @rocketknight1
98100
/tensorflow_addons/layers/tests/embedding_bag_test.py @rocketknight1
99101

tensorflow_addons/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tensorflow_addons.layers.embedding_bag import EmbeddingBag
2727
from tensorflow_addons.layers.gelu import GELU
2828
from tensorflow_addons.layers.max_unpooling_2d import MaxUnpooling2D
29+
from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2
2930
from tensorflow_addons.layers.maxout import Maxout
3031
from tensorflow_addons.layers.multihead_attention import MultiHeadAttention
3132
from tensorflow_addons.layers.normalizations import FilterResponseNormalization
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""MaxUnpooling2DV2 operation."""
16+
17+
import tensorflow as tf
18+
19+
from typeguard import typechecked
20+
from typing import Iterable
21+
22+
from tensorflow_addons.utils.keras_utils import normalize_tuple
23+
24+
25+
def _max_unpooling_2d_v2(updates, mask, output_size):
26+
"""Unpool the outputs of a maximum pooling operation."""
27+
mask = tf.cast(mask, "int32")
28+
input_shape = tf.shape(updates, out_type="int32")
29+
input_shape = [updates.shape[i] or input_shape[i] for i in range(4)]
30+
output_shape = output_size
31+
32+
# Calculates indices for batch, height, width and feature maps.
33+
one_like_mask = tf.ones_like(mask, dtype="int32")
34+
batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], axis=0)
35+
batch_range = tf.reshape(
36+
tf.range(output_shape[0], dtype="int32"), shape=batch_shape
37+
)
38+
b = one_like_mask * batch_range
39+
y = mask // (output_shape[2] * output_shape[3])
40+
x = (mask // output_shape[3]) % output_shape[2]
41+
feature_range = tf.range(output_shape[3], dtype="int32")
42+
f = one_like_mask * feature_range
43+
44+
# Transposes indices & reshape update values to one dimension.
45+
updates_size = tf.size(updates)
46+
indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
47+
values = tf.reshape(updates, [updates_size])
48+
ret = tf.scatter_nd(indices, values, output_shape)
49+
return ret
50+
51+
52+
@tf.keras.utils.register_keras_serializable(package="Addons")
53+
class MaxUnpooling2DV2(tf.keras.layers.Layer):
54+
"""Unpool the outputs of a maximum pooling operation.
55+
56+
This differs from MaxUnpooling2D in that it uses output_size rather than strides and padding
57+
to calculate the unpooled tensor. This is because MaxPoolingWithArgMax can map several input
58+
sizes to the same output size, and specifying the output size avoids ambiguity in the
59+
inversion process.
60+
61+
This function currently does not support outputs of MaxPoolingWithArgMax in following cases:
62+
- include_batch_in_index equals true.
63+
- The max pooling operation results in duplicate values in updates and mask.
64+
65+
Args:
66+
output_size: A tuple/list of 4 integers specifying (batch_size, height, width, channel).
67+
The targeted output size.
68+
Call Args:
69+
updates: A 4D tensor of shape `(batch_size, height, width, channel)`.
70+
The pooling result from max pooling.
71+
mask: A 4D tensor of shape `(batch_size, height, width, channel)`.
72+
The indices of the maximal values.
73+
Output shape:
74+
4D tensor with the same shape as output_size.
75+
"""
76+
77+
@typechecked
78+
def __init__(
79+
self,
80+
output_size: Iterable[int],
81+
**kwargs,
82+
):
83+
super(MaxUnpooling2DV2, self).__init__(**kwargs)
84+
85+
self.output_size = normalize_tuple(output_size, 4, "output_size")
86+
87+
def call(self, updates, mask):
88+
return _max_unpooling_2d_v2(updates, mask, output_size=self.output_size)
89+
90+
def get_config(self):
91+
config = super(MaxUnpooling2DV2, self).get_config()
92+
config["output_size"] = self.output_size
93+
return config
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Tests for MaxUnpooling2DV2 layers."""
16+
17+
import numpy as np
18+
import pytest
19+
import tensorflow as tf
20+
from tensorflow_addons.layers.max_unpooling_2d_v2 import MaxUnpooling2DV2
21+
22+
23+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
24+
def test_simple():
25+
valid_input = np.array([13, 4]).astype(np.float32)
26+
valid_input = np.reshape(valid_input, (1, 1, 2, 1))
27+
indices = np.array([1, 6]).astype(np.float32)
28+
indices = np.reshape(indices, (1, 1, 2, 1))
29+
output_shape = (1, 2, 4, 1)
30+
expected_output = np.array([0, 13, 0, 0, 0, 0, 4, 0]).astype(np.float32)
31+
expected_output = np.reshape(expected_output, output_shape)
32+
33+
output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy()
34+
np.testing.assert_array_equal(expected_output, output)
35+
36+
37+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
38+
def test_complex():
39+
valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
40+
valid_input = np.reshape(valid_input, (1, 2, 2, 2))
41+
indices = np.array([0, 3, 4, 7, 8, 11, 12, 15]).astype(np.float32)
42+
indices = np.reshape(indices, (1, 2, 2, 2))
43+
output_shape = (1, 4, 2, 2)
44+
expected_output = np.array([1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, 6, 7, 0, 0, 8]).astype(
45+
np.float32
46+
)
47+
expected_output = np.reshape(expected_output, output_shape)
48+
49+
output = MaxUnpooling2DV2(output_shape)(valid_input, indices).numpy()
50+
np.testing.assert_array_equal(expected_output, output)
51+
52+
53+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
54+
def test_batch():
55+
valid_input = np.array(
56+
# fmt: off
57+
[
58+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
59+
22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
60+
]
61+
# fmt: on
62+
).astype(np.float32)
63+
valid_input = np.reshape(valid_input, (2, 2, 4, 2))
64+
indices = np.array(
65+
# fmt: off
66+
[
67+
2, 23, 8, 9, 12, 15, 40, 43, 44, 47, 72, 75, 80, 79, 62, 65, 0, 1, 30, 7,
68+
14, 35, 42, 21, 68, 69, 50, 51, 56, 5, 86, 63
69+
]
70+
# fmt: on
71+
).astype(np.float32)
72+
indices = np.reshape(indices, (2, 2, 4, 2))
73+
output_shape = (2, 4, 12, 2)
74+
expected_output = np.array(
75+
# fmt: off
76+
[
77+
0, 0, 1, 0, 0, 0, 0, 0, 3, 4, 0, 0, 5, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 2, 0,
78+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 8, 9, 0, 0, 10, 0,
79+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0, 0, 16, 0, 0, 0, 0, 0, 0, 11,
80+
0, 0, 12, 0, 0, 0, 14, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
81+
17, 18, 0, 0, 0, 30, 0, 20, 0, 0, 0, 0, 0, 0, 21, 0, 0, 0, 0, 0, 0, 24, 0,
82+
0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 22, 0, 0, 0, 0, 0, 0, 23, 0, 0, 0, 0,
83+
0, 0, 0, 27, 28, 0, 0, 0, 0, 29, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 25, 26,
84+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 31, 0, 0, 0, 0, 0, 0, 0,
85+
0, 0
86+
]
87+
# fmt: on
88+
).astype(np.float32)
89+
expected_output = np.reshape(expected_output, output_shape)
90+
91+
output = MaxUnpooling2DV2(output_shape)(valid_input, indices)
92+
np.testing.assert_array_equal(expected_output, output)
93+
94+
95+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
96+
def test_with_pooling_simple():
97+
valid_input = np.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(np.float32)
98+
valid_input = np.reshape(valid_input, (1, 2, 4, 1))
99+
updates, indices = tf.nn.max_pool_with_argmax(
100+
valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME"
101+
)
102+
expected_output = np.array([0, 0, 0, 0, 0, 6, 0, 8]).astype(np.float32)
103+
expected_output = np.reshape(expected_output, valid_input.shape)
104+
105+
output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy()
106+
np.testing.assert_array_equal(expected_output, output)
107+
108+
109+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
110+
def test_with_pooling():
111+
valid_input = np.array(
112+
[1, 2, 4, 3, 8, 6, 7, 5, 9, 10, 12, 11, 13, 16, 15, 14]
113+
).astype(np.float32)
114+
valid_input = np.reshape(valid_input, (1, 4, 4, 1))
115+
updates, indices = tf.nn.max_pool_with_argmax(
116+
valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME"
117+
)
118+
expected_output = np.array(
119+
[0, 0, 0, 0, 8, 0, 7, 0, 0, 0, 0, 0, 0, 16, 15, 0]
120+
).astype(np.float32)
121+
expected_output = np.reshape(expected_output, valid_input.shape)
122+
123+
output = MaxUnpooling2DV2(valid_input.shape)(updates, indices).numpy()
124+
np.testing.assert_array_equal(expected_output, output)
125+
126+
127+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
128+
def test_symbolic_tensor_shape():
129+
valid_input = tf.keras.layers.Input((None, None, 1))
130+
updates, indices = tf.nn.max_pool_with_argmax(
131+
valid_input, ksize=[2, 2], strides=[2, 2], padding="SAME"
132+
)
133+
with pytest.raises(ValueError):
134+
MaxUnpooling2DV2(valid_input.shape)(updates, indices)

0 commit comments

Comments
 (0)