Skip to content

Commit 22abf3c

Browse files
mels630WindQAQ
authored andcommitted
Add unit test verifying tf.Dataset support for dense_image_warp (#332) (#654)
1 parent e4c974d commit 22abf3c

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

tensorflow_addons/image/dense_image_warp_test.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,16 +154,25 @@ def _check_interpolation_correctness(self,
154154
shape,
155155
image_type,
156156
flow_type,
157+
call_with_unknown_shapes=False,
157158
num_probes=5):
158159
"""Interpolate, and then assert correctness for a few query
159160
locations."""
160161
low_precision = image_type == "float16" or flow_type == "float16"
161162
rand_image, rand_flows = self._get_random_image_and_flows(
162163
shape, image_type, flow_type)
163164

164-
interp = dense_image_warp(
165-
image=tf.convert_to_tensor(rand_image),
166-
flow=tf.convert_to_tensor(rand_flows))
165+
if call_with_unknown_shapes:
166+
fn = dense_image_warp.get_concrete_function(
167+
tf.TensorSpec(shape=None, dtype=image_type),
168+
tf.TensorSpec(shape=None, dtype=flow_type))
169+
interp = fn(
170+
image=tf.convert_to_tensor(rand_image),
171+
flow=tf.convert_to_tensor(rand_flows))
172+
else:
173+
interp = dense_image_warp(
174+
image=tf.convert_to_tensor(rand_image),
175+
flow=tf.convert_to_tensor(rand_flows))
167176

168177
for _ in range(num_probes):
169178
batch_index = np.random.randint(0, shape[0])
@@ -189,6 +198,14 @@ def test_interpolation(self):
189198
self._check_interpolation_correctness(
190199
shape, im_type, flow_type)
191200

201+
def test_unknown_shapes(self):
202+
"""Apply _check_interpolation_correctness() for a few sizes and check
203+
for tf.Dataset compatibility."""
204+
shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]]
205+
for shape in shapes_to_try:
206+
self._check_interpolation_correctness(shape, "float32", "float32",
207+
True)
208+
192209
def test_gradients_exist(self):
193210
"""Check that backprop can run.
194211

0 commit comments

Comments
 (0)