Skip to content

Commit a813255

Browse files
committed
Merge branch 'master' into BLD/add_cuda_py_test
2 parents 075e641 + f2a2e2e commit a813255

File tree

6 files changed

+70
-74
lines changed

6 files changed

+70
-74
lines changed

examples/optimizers_lazyadam.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
"source": [
9191
"# LazyAdam\n",
9292
"\n",
93-
"> LazyAdam is a variant of the Adam optimizer that handles sparse updates moreefficiently.\n",
93+
"> LazyAdam is a variant of the Adam optimizer that handles sparse updates more efficiently.\n",
9494
" The original Adam algorithm maintains two moving-average accumulators for\n",
9595
" each trainable variable; the accumulators are updated at every step.\n",
9696
" This class provides lazier handling of gradient updates for sparse\n",

tensorflow_addons/image/distance_transform.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import print_function
1919

2020
import tensorflow as tf
21+
from tensorflow_addons.image import utils as img_utils
2122
from tensorflow_addons.utils.resource_loader import get_path_to_datafile
2223

2324
_image_ops_so = tf.load_op_library(
@@ -32,8 +33,7 @@ def euclidean_dist_transform(images, dtype=tf.float32, name=None):
3233
3334
Args:
3435
images: A tensor of shape (num_images, num_rows, num_columns, 1) (NHWC),
35-
or (num_rows, num_columns, 1) (HWC). The rank must be statically known
36-
(the shape is not `TensorShape(None)`.
36+
or (num_rows, num_columns, 1) (HWC) or (num_rows, num_columns) (HW).
3737
dtype: DType of the output tensor.
3838
name: The name of the op.
3939
@@ -45,7 +45,7 @@ def euclidean_dist_transform(images, dtype=tf.float32, name=None):
4545
Raises:
4646
TypeError: If `image` is not tf.uint8, or `dtype` is not floating point.
4747
ValueError: If `image` more than one channel, or `image` is not of
48-
rank 3 or 4.
48+
rank between 2 and 4.
4949
"""
5050

5151
with tf.name_scope(name or "euclidean_distance_transform"):
@@ -54,14 +54,9 @@ def euclidean_dist_transform(images, dtype=tf.float32, name=None):
5454
if image_or_images.dtype.base_dtype != tf.uint8:
5555
raise TypeError(
5656
"Invalid dtype %s. Expected uint8." % image_or_images.dtype)
57-
if image_or_images.get_shape().ndims is None:
58-
raise ValueError("`images` rank must be statically known")
59-
elif len(image_or_images.get_shape()) == 3:
60-
images = image_or_images[None, :, :, :]
61-
elif len(image_or_images.get_shape()) == 4:
62-
images = image_or_images
63-
else:
64-
raise ValueError("`images` should have rank between 3 and 4")
57+
58+
images = img_utils.to_4D_image(image_or_images)
59+
original_ndims = img_utils.get_ndims(image_or_images)
6560

6661
if images.get_shape()[3] != 1 and images.get_shape()[3] is not None:
6762
raise ValueError("`images` must have only one channel")
@@ -72,6 +67,4 @@ def euclidean_dist_transform(images, dtype=tf.float32, name=None):
7267
images = tf.cast(images, dtype)
7368
output = _image_ops_so.euclidean_distance_transform(images)
7469

75-
if len(image_or_images.get_shape()) == 3:
76-
return output[0, :, :, :]
77-
return output
70+
return img_utils.from_4D_image(output, original_ndims)

tensorflow_addons/image/distance_transform_test.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import tensorflow as tf
2323

24-
from tensorflow_addons.image import distance_transform as distance_tranform_ops
24+
from tensorflow_addons.image import distance_transform as dist_ops
2525
from tensorflow_addons.utils import test_utils
2626

2727

@@ -45,7 +45,7 @@ def test_single_binary_image(self):
4545
image = tf.constant(image, dtype=tf.uint8)
4646

4747
for output_dtype in [tf.float16, tf.float32, tf.float64]:
48-
output = distance_tranform_ops.euclidean_dist_transform(
48+
output = dist_ops.euclidean_dist_transform(
4949
image, dtype=output_dtype)
5050
output_flat = tf.reshape(output, [-1])
5151

@@ -73,7 +73,7 @@ def test_batch_binary_images(self):
7373
# yapf: enable
7474
images = tf.constant([image] * batch_size, dtype=tf.uint8)
7575
for output_dtype in [tf.float16, tf.float32, tf.float64]:
76-
output = distance_tranform_ops.euclidean_dist_transform(
76+
output = dist_ops.euclidean_dist_transform(
7777
images, dtype=output_dtype)
7878
output_flat = tf.reshape(output, [-1])
7979

@@ -97,38 +97,37 @@ def test_image_with_invalid_dtype(self):
9797
# pylint: disable=bad-continuation
9898
with self.assertRaisesRegex(
9999
TypeError, "`dtype` must be float16, float32 or float64"):
100-
_ = distance_tranform_ops.euclidean_dist_transform(
100+
_ = dist_ops.euclidean_dist_transform(
101101
image, dtype=output_dtype)
102102

103103
def test_image_with_invalid_shape(self):
104-
for invalid_shape in ([1], [2, 1], [2, 4, 4, 4, 1]):
105-
image = tf.zeros(invalid_shape, tf.uint8)
106-
107-
# pylint: disable=bad-continuation
108-
with self.assertRaisesRegex(
109-
ValueError, "`images` should have rank between 3 and 4"):
110-
_ = distance_tranform_ops.euclidean_dist_transform(image)
111-
112104
image = tf.zeros([2, 4, 3], tf.uint8)
113105
with self.assertRaisesRegex(ValueError,
114106
"`images` must have only one channel"):
115-
_ = distance_tranform_ops.euclidean_dist_transform(image)
107+
_ = dist_ops.euclidean_dist_transform(image)
116108

117109
def test_all_zeros(self):
118-
image = tf.zeros([10, 10, 1], tf.uint8)
119-
expected_output = np.zeros([10, 10, 1])
110+
image = tf.zeros([10, 10], tf.uint8)
111+
expected_output = np.zeros([10, 10])
120112

121113
for output_dtype in [tf.float16, tf.float32, tf.float64]:
122-
output = distance_tranform_ops.euclidean_dist_transform(
114+
output = dist_ops.euclidean_dist_transform(
123115
image, dtype=output_dtype)
124116
self.assertAllClose(output, expected_output)
125117

126118
def test_all_ones(self):
127119
image = tf.ones([10, 10, 1], tf.uint8)
128-
output = distance_tranform_ops.euclidean_dist_transform(image)
120+
output = dist_ops.euclidean_dist_transform(image)
129121
expected_output = np.full([10, 10, 1], tf.float32.max)
130122
self.assertAllClose(output, expected_output)
131123

124+
def test_unknown_shape(self):
125+
fn = dist_ops.euclidean_dist_transform.get_concrete_function(
126+
tf.TensorSpec(None, tf.uint8))
127+
for shape in [[5, 10], [10, 7, 1], [4, 10, 10, 1]]:
128+
image = tf.zeros(shape, dtype=tf.uint8)
129+
self.assertAllClose(image, fn(image))
130+
132131

133132
if __name__ == "__main__":
134133
tf.test.main()

tensorflow_addons/image/filters.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import print_function
1919

2020
import tensorflow as tf
21+
from tensorflow_addons.image import utils as img_utils
2122
from tensorflow_addons.utils import keras_utils
2223

2324

@@ -59,7 +60,8 @@ def mean_filter2d(image,
5960
"""Perform mean filtering on image(s).
6061
6162
Args:
62-
image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
63+
image: Either a 2-D `Tensor` of shape `[height, width]`,
64+
a 3-D `Tensor` of shape `[height, width, channels]`,
6365
or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
6466
filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
6567
the height and width of the 2-D mean filter. Can be a single integer
@@ -74,16 +76,14 @@ def mean_filter2d(image,
7476
Returns:
7577
3-D or 4-D `Tensor` of the same dtype as input.
7678
Raises:
77-
ValueError: If `image` is not 3 or 4-dimensional,
79+
ValueError: If `image` is not 2, 3 or 4-dimensional,
7880
if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
7981
or if `filter_shape` is invalid.
8082
"""
8183
with tf.name_scope(name or "mean_filter2d"):
8284
image = tf.convert_to_tensor(image, name="image")
83-
84-
rank = image.shape.rank
85-
if rank != 3 and rank != 4:
86-
raise ValueError("image should be either 3 or 4-dimensional.")
85+
original_ndims = img_utils.get_ndims(image)
86+
image = img_utils.to_4D_image(image)
8787

8888
if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]:
8989
raise ValueError(
@@ -93,10 +93,6 @@ def mean_filter2d(image,
9393
filter_shape = keras_utils.normalize_tuple(filter_shape, 2,
9494
"filter_shape")
9595

96-
# Expand to a 4-D tensor
97-
if rank == 3:
98-
image = tf.expand_dims(image, axis=0)
99-
10096
# Keep the precision if it's float;
10197
# otherwise, convert to float32 for computing.
10298
orig_dtype = image.dtype
@@ -119,11 +115,7 @@ def mean_filter2d(image,
119115

120116
output /= area
121117

122-
# Squeeze out the first axis to make sure
123-
# output has the same dimension with image.
124-
if rank == 3:
125-
output = tf.squeeze(output, axis=0)
126-
118+
output = img_utils.from_4D_image(output, original_ndims)
127119
return tf.dtypes.cast(output, orig_dtype)
128120

129121

@@ -136,7 +128,8 @@ def median_filter2d(image,
136128
"""Perform median filtering on image(s).
137129
138130
Args:
139-
image: Either a 3-D `Tensor` of shape `[height, width, channels]`,
131+
image: Either a 2-D `Tensor` of shape `[height, width]`,
132+
a 3-D `Tensor` of shape `[height, width, channels]`,
140133
or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
141134
filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying
142135
the height and width of the 2-D median filter. Can be a single integer
@@ -151,16 +144,14 @@ def median_filter2d(image,
151144
Returns:
152145
3-D or 4-D `Tensor` of the same dtype as input.
153146
Raises:
154-
ValueError: If `image` is not 3 or 4-dimensional,
147+
ValueError: If `image` is not 2, 3 or 4-dimensional,
155148
if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC",
156149
or if `filter_shape` is invalid.
157150
"""
158151
with tf.name_scope(name or "median_filter2d"):
159152
image = tf.convert_to_tensor(image, name="image")
160-
161-
rank = image.shape.rank
162-
if rank != 3 and rank != 4:
163-
raise ValueError("image should be either 3 or 4-dimensional.")
153+
original_ndims = img_utils.get_ndims(image)
154+
image = img_utils.to_4D_image(image)
164155

165156
if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]:
166157
raise ValueError(
@@ -170,10 +161,6 @@ def median_filter2d(image,
170161
filter_shape = keras_utils.normalize_tuple(filter_shape, 2,
171162
"filter_shape")
172163

173-
# Expand to a 4-D tensor
174-
if rank == 3:
175-
image = tf.expand_dims(image, axis=0)
176-
177164
image_shape = tf.shape(image)
178165
batch_size = image_shape[0]
179166
height = image_shape[1]
@@ -212,10 +199,5 @@ def median_filter2d(image,
212199
top[:, :, :, :, floor - 1] + top[:, :, :, :, ceil - 1]) / 2
213200

214201
output = tf.cast(median, image.dtype)
215-
216-
# Squeeze out the first axis to make sure
217-
# output has the same dimension with image.
218-
if rank == 3:
219-
output = tf.squeeze(output, axis=0)
220-
202+
output = img_utils.from_4D_image(output, original_ndims)
221203
return output

tensorflow_addons/image/filters_test.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,13 @@ def setUp(self):
8282
super(MeanFilter2dTest, self).setUp()
8383

8484
def test_invalid_image(self):
85-
msg = "image should be either 3 or 4-dimensional."
86-
87-
for image_shape in [(28, 28), (16, 28, 28, 1, 1)]:
85+
msg = "`image` must be 2/3/4D tensor"
86+
errors = (ValueError, tf.errors.InvalidArgumentError)
87+
for image_shape in [(1,), (16, 28, 28, 1, 1)]:
8888
with self.subTest(dim=len(image_shape)):
89-
with self.assertRaisesRegexp(ValueError, msg):
90-
mean_filter2d(tf.ones(shape=image_shape))
89+
with self.assertRaisesRegexp(errors, msg):
90+
image = tf.ones(shape=image_shape)
91+
self.evaluate(mean_filter2d(image))
9192

9293
def test_invalid_filter_shape(self):
9394
msg = ("The `filter_shape` argument must be a tuple of 2 integers.")
@@ -119,6 +120,16 @@ def test_none_channels(self):
119120
fn(tf.ones(shape=(1, 3, 3, 1)))
120121
fn(tf.ones(shape=(1, 3, 3, 3)))
121122

123+
def test_unknown_shape(self):
124+
fn = mean_filter2d.get_concrete_function(
125+
tf.TensorSpec(shape=None, dtype=tf.dtypes.float32),
126+
padding="CONSTANT",
127+
constant_values=1.)
128+
129+
for shape in [(3, 3), (3, 3, 3), (1, 3, 3, 3)]:
130+
image = tf.ones(shape=shape)
131+
self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image)))
132+
122133
def test_reflect_padding_with_3x3_filter(self):
123134
expected_plane = tf.constant([[33. / 9., 36. / 9., 39. / 9.],
124135
[42. / 9., 45. / 9., 48. / 9.],
@@ -191,12 +202,13 @@ def setUp(self):
191202
super(MedianFilter2dTest, self).setUp()
192203

193204
def test_invalid_image(self):
194-
msg = "image should be either 3 or 4-dimensional."
195-
196-
for image_shape in [(28, 28), (16, 28, 28, 1, 1)]:
205+
msg = "`image` must be 2/3/4D tensor"
206+
errors = (ValueError, tf.errors.InvalidArgumentError)
207+
for image_shape in [(1,), (16, 28, 28, 1, 1)]:
197208
with self.subTest(dim=len(image_shape)):
198-
with self.assertRaisesRegexp(ValueError, msg):
199-
median_filter2d(tf.ones(shape=image_shape))
209+
with self.assertRaisesRegexp(errors, msg):
210+
image = tf.ones(shape=image_shape)
211+
self.evaluate(median_filter2d(image))
200212

201213
def test_invalid_filter_shape(self):
202214
msg = ("The `filter_shape` argument must be a tuple of 2 integers.")
@@ -228,6 +240,16 @@ def test_none_channels(self):
228240
fn(tf.ones(shape=(1, 3, 3, 1)))
229241
fn(tf.ones(shape=(1, 3, 3, 3)))
230242

243+
def test_unknown_shape(self):
244+
fn = median_filter2d.get_concrete_function(
245+
tf.TensorSpec(shape=None, dtype=tf.dtypes.float32),
246+
padding="CONSTANT",
247+
constant_values=1.)
248+
249+
for shape in [(3, 3), (3, 3, 3), (1, 3, 3, 3)]:
250+
image = tf.ones(shape=shape)
251+
self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image)))
252+
231253
def test_reflect_padding_with_3x3_filter(self):
232254
expected_plane = tf.constant([[4, 4, 5], [5, 5, 5], [5, 6, 6]])
233255

tools/ci_testing/addons_gpu.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ fi
2525

2626
set -x
2727

28-
N_JOBS=$(grep -c ^processor /proc/cpuinfo)
28+
N_JOBS=1 # Must limit GPU testing to single job to prevent OOM error.
2929

3030
echo ""
3131
echo "Bazel will use ${N_JOBS} concurrent job(s)."

0 commit comments

Comments
 (0)