Skip to content

Commit f26abbb

Browse files
WindQAQseanpmorgan
authored andcommitted
[image/filters] convert unknown rank to 4d (#343)
* convert unknown rank to 4d
1 parent 4f243cb commit f26abbb

File tree

2 files changed

+45
-41
lines changed

2 files changed

+45
-41
lines changed

tensorflow_addons/image/filters.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import print_function
1919

2020
import tensorflow as tf
21+
from tensorflow_addons.image import utils as img_utils
2122
from tensorflow_addons.utils import keras_utils
2223

2324

@@ -59,7 +60,8 @@ def mean_filter2d(image,
5960
"""Perform mean filtering on image(s).
6061
6162
Args:
62-
image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
63+
image: Either a 2-D `Tensor` of shape `[height, width]`,
64+
a 3-D `Tensor` of shape `[height, width, channels]`,
6365
or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
6466
filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
6567
the height and width of the 2-D mean filter. Can be a single integer
@@ -74,16 +76,14 @@ def mean_filter2d(image,
7476
Returns:
7577
3-D or 4-D `Tensor` of the same dtype as input.
7678
Raises:
77-
ValueError: If `image` is not 3 or 4-dimensional,
79+
ValueError: If `image` is not 2, 3 or 4-dimensional,
7880
if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
7981
or if `filter_shape` is invalid.
8082
"""
8183
with tf.name_scope(name or "mean_filter2d"):
8284
image = tf.convert_to_tensor(image, name="image")
83-
84-
rank = image.shape.rank
85-
if rank != 3 and rank != 4:
86-
raise ValueError("image should be either 3 or 4-dimensional.")
85+
original_ndims = img_utils.get_ndims(image)
86+
image = img_utils.to_4D_image(image)
8787

8888
if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]:
8989
raise ValueError(
@@ -93,10 +93,6 @@ def mean_filter2d(image,
9393
filter_shape = keras_utils.normalize_tuple(filter_shape, 2,
9494
"filter_shape")
9595

96-
# Expand to a 4-D tensor
97-
if rank == 3:
98-
image = tf.expand_dims(image, axis=0)
99-
10096
# Keep the precision if it's float;
10197
# otherwise, convert to float32 for computing.
10298
orig_dtype = image.dtype
@@ -119,11 +115,7 @@ def mean_filter2d(image,
119115

120116
output /= area
121117

122-
# Squeeze out the first axis to make sure
123-
# output has the same dimension with image.
124-
if rank == 3:
125-
output = tf.squeeze(output, axis=0)
126-
118+
output = img_utils.from_4D_image(output, original_ndims)
127119
return tf.dtypes.cast(output, orig_dtype)
128120

129121

@@ -136,7 +128,8 @@ def median_filter2d(image,
136128
"""Perform median filtering on image(s).
137129
138130
Args:
139-
image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
131+
image: Either a 2-D `Tensor` of shape `[height, width]`,
132+
a 3-D `Tensor` of shape `[height, width, channels]`,
140133
or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
141134
filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
142135
the height and width of the 2-D median filter. Can be a single integer
@@ -151,16 +144,14 @@ def median_filter2d(image,
151144
Returns:
152145
3-D or 4-D `Tensor` of the same dtype as input.
153146
Raises:
154-
ValueError: If `image` is not 3 or 4-dimensional,
147+
ValueError: If `image` is not 2, 3 or 4-dimensional,
155148
if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
156149
or if `filter_shape` is invalid.
157150
"""
158151
with tf.name_scope(name or "median_filter2d"):
159152
image = tf.convert_to_tensor(image, name="image")
160-
161-
rank = image.shape.rank
162-
if rank != 3 and rank != 4:
163-
raise ValueError("image should be either 3 or 4-dimensional.")
153+
original_ndims = img_utils.get_ndims(image)
154+
image = img_utils.to_4D_image(image)
164155

165156
if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]:
166157
raise ValueError(
@@ -170,10 +161,6 @@ def median_filter2d(image,
170161
filter_shape = keras_utils.normalize_tuple(filter_shape, 2,
171162
"filter_shape")
172163

173-
# Expand to a 4-D tensor
174-
if rank == 3:
175-
image = tf.expand_dims(image, axis=0)
176-
177164
image_shape = tf.shape(image)
178165
batch_size = image_shape[0]
179166
height = image_shape[1]
@@ -212,10 +199,5 @@ def median_filter2d(image,
212199
top[:, :, :, :, floor - 1] + top[:, :, :, :, ceil - 1]) / 2
213200

214201
output = tf.cast(median, image.dtype)
215-
216-
# Squeeze out the first axis to make sure
217-
# output has the same dimension with image.
218-
if rank == 3:
219-
output = tf.squeeze(output, axis=0)
220-
202+
output = img_utils.from_4D_image(output, original_ndims)
221203
return output

tensorflow_addons/image/filters_test.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,13 @@ def setUp(self):
8282
super(MeanFilter2dTest, self).setUp()
8383

8484
def test_invalid_image(self):
85-
msg = "image should be either 3 or 4-dimensional."
86-
87-
for image_shape in [(28, 28), (16, 28, 28, 1, 1)]:
85+
msg = "`image` must be 2/3/4D tensor"
86+
errors = (ValueError, tf.errors.InvalidArgumentError)
87+
for image_shape in [(1,), (16, 28, 28, 1, 1)]:
8888
with self.subTest(dim=len(image_shape)):
89-
with self.assertRaisesRegexp(ValueError, msg):
90-
mean_filter2d(tf.ones(shape=image_shape))
89+
with self.assertRaisesRegexp(errors, msg):
90+
image = tf.ones(shape=image_shape)
91+
self.evaluate(mean_filter2d(image))
9192

9293
def test_invalid_filter_shape(self):
9394
msg = ("The `filter_shape` argument must be a tuple of 2 integers.")
@@ -119,6 +120,16 @@ def test_none_channels(self):
119120
fn(tf.ones(shape=(1, 3, 3, 1)))
120121
fn(tf.ones(shape=(1, 3, 3, 3)))
121122

123+
def test_unknown_shape(self):
124+
fn = mean_filter2d.get_concrete_function(
125+
tf.TensorSpec(shape=None, dtype=tf.dtypes.float32),
126+
padding="CONSTANT",
127+
constant_values=1.)
128+
129+
for shape in [(3, 3), (3, 3, 3), (1, 3, 3, 3)]:
130+
image = tf.ones(shape=shape)
131+
self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image)))
132+
122133
def test_reflect_padding_with_3x3_filter(self):
123134
expected_plane = tf.constant([[33. / 9., 36. / 9., 39. / 9.],
124135
[42. / 9., 45. / 9., 48. / 9.],
@@ -191,12 +202,13 @@ def setUp(self):
191202
super(MedianFilter2dTest, self).setUp()
192203

193204
def test_invalid_image(self):
194-
msg = "image should be either 3 or 4-dimensional."
195-
196-
for image_shape in [(28, 28), (16, 28, 28, 1, 1)]:
205+
msg = "`image` must be 2/3/4D tensor"
206+
errors = (ValueError, tf.errors.InvalidArgumentError)
207+
for image_shape in [(1,), (16, 28, 28, 1, 1)]:
197208
with self.subTest(dim=len(image_shape)):
198-
with self.assertRaisesRegexp(ValueError, msg):
199-
median_filter2d(tf.ones(shape=image_shape))
209+
with self.assertRaisesRegexp(errors, msg):
210+
image = tf.ones(shape=image_shape)
211+
self.evaluate(median_filter2d(image))
200212

201213
def test_invalid_filter_shape(self):
202214
msg = ("The `filter_shape` argument must be a tuple of 2 integers.")
@@ -228,6 +240,16 @@ def test_none_channels(self):
228240
fn(tf.ones(shape=(1, 3, 3, 1)))
229241
fn(tf.ones(shape=(1, 3, 3, 3)))
230242

243+
def test_unknown_shape(self):
244+
fn = median_filter2d.get_concrete_function(
245+
tf.TensorSpec(shape=None, dtype=tf.dtypes.float32),
246+
padding="CONSTANT",
247+
constant_values=1.)
248+
249+
for shape in [(3, 3), (3, 3, 3), (1, 3, 3, 3)]:
250+
image = tf.ones(shape=shape)
251+
self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image)))
252+
231253
def test_reflect_padding_with_3x3_filter(self):
232254
expected_plane = tf.constant([[4, 4, 5], [5, 5, 5], [5, 6, 6]])
233255

0 commit comments

Comments
 (0)