Skip to content

Commit 4f1d422

Browse files
committed
Use public API for name_scope and convert_to_tensor
1 parent 56bf5d1 commit 4f1d422

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

tensorflow_addons/image/dense_image_warp.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import tensorflow as tf
2222

2323

24-
@tf.function
2524
def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
2625
"""Similar to Matlab's interp2 function.
2726
@@ -48,31 +47,33 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
4847
with tf.name_scope(name or "interpolate_bilinear"):
4948
grid = tf.convert_to_tensor(grid)
5049
query_points = tf.convert_to_tensor(query_points)
51-
shape = grid.get_shape().as_list()
52-
if len(shape) != 4:
50+
51+
if len(grid.shape) != 4:
5352
msg = "Grid must be 4 dimensional. Received size: "
54-
raise ValueError(msg + str(grid.get_shape()))
53+
raise ValueError(msg + str(grid.shape))
54+
55+
if len(query_points.shape) != 3:
56+
raise ValueError("Query points must be 3 dimensional.")
57+
58+
grid_shape = tf.shape(grid)
59+
query_shape = tf.shape(query_points)
5560

56-
batch_size, height, width, channels = (tf.shape(grid)[0],
57-
tf.shape(grid)[1],
58-
tf.shape(grid)[2],
59-
tf.shape(grid)[3])
61+
batch_size, height, width, channels = (grid_shape[0],
62+
grid_shape[1],
63+
grid_shape[2],
64+
grid_shape[3])
6065

6166
shape = [batch_size, height, width, channels]
67+
num_queries = query_shape[1]
68+
6269
query_type = query_points.dtype
6370
grid_type = grid.dtype
6471

6572
tf.debugging.assert_equal(
66-
len(query_points.get_shape()),
67-
3,
68-
message="Query points must be 3 dimensional.")
69-
tf.debugging.assert_equal(
70-
tf.shape(query_points)[2],
73+
query_shape[2],
7174
2,
7275
message="Query points must be size 2 in dim 2.")
7376

74-
num_queries = tf.shape(query_points)[1]
75-
7677
tf.debugging.assert_greater_equal(
7778
height, 2, message="Grid height must be at least 2."),
7879
tf.debugging.assert_greater_equal(

tensorflow_addons/image/dense_image_warp_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def _check_zero_flow_correctness(self, shape, image_type, flow_type):
134134

135135
self.assertAllClose(rand_image, interp)
136136

137-
# TODO: run in both graph and eager modes
137+
@test_utils.run_in_graph_and_eager_modes
138138
def test_zero_flows(self):
139139
"""Apply _check_zero_flow_correctness() for a few sizes and types."""
140140
shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]

0 commit comments

Comments
 (0)