@@ -55,31 +55,49 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
5555 if len (query_points .shape ) != 3 :
5656 raise ValueError ("Query points must be 3 dimensional." )
5757
58+ if query_points .shape [2 ] is not None and query_points .shape [2 ] != 2 :
59+ raise ValueError ("Query points must be size 2 in dim 2." )
60+
61+ if grid .shape [1 ] is not None and grid .shape [1 ] < 2 :
62+ raise ValueError ("Grid height must be at least 2." )
63+
64+ if grid .shape [2 ] is not None and grid .shape [2 ] < 2 :
65+ raise ValueError ("Grid width must be at least 2." )
66+
5867 grid_shape = tf .shape (grid )
5968 query_shape = tf .shape (query_points )
6069
6170 batch_size , height , width , channels = (grid_shape [0 ], grid_shape [1 ],
6271 grid_shape [2 ], grid_shape [3 ])
6372
6473 shape = [batch_size , height , width , channels ]
65- num_queries = query_shape [1 ]
74+
75+ # pylint: disable=bad-continuation
76+ with tf .control_dependencies ([
77+ tf .debugging .assert_equal (
78+ query_shape [2 ],
79+ 2 ,
80+ message = "Query points must be size 2 in dim 2." )
81+ ]):
82+ num_queries = query_shape [1 ]
83+ # pylint: enable=bad-continuation
6684
6785 query_type = query_points .dtype
6886 grid_type = grid .dtype
6987
70- tf . debugging . assert_equal (
71- query_shape [ 2 ], 2 , message = "Query points must be size 2 in dim 2." )
72-
73- tf . debugging . assert_greater_equal (
74- height , 2 , message = "Grid height must be at least 2." ),
75- tf . debugging . assert_greater_equal (
76- width , 2 , message = "Grid width must be at least 2." )
77-
78- alphas = []
79- floors = []
80- ceils = []
81- index_order = [ 0 , 1 ] if indexing == "ij" else [ 1 , 0 ]
82- unstacked_query_points = tf . unstack ( query_points , axis = 2 )
88+ # pylint: disable=bad-continuation
89+ with tf . control_dependencies ([
90+ tf . debugging . assert_greater_equal (
91+ height , 2 , message = "Grid height must be at least 2." ),
92+ tf . debugging . assert_greater_equal (
93+ width , 2 , message = "Grid width must be at least 2." ),
94+ ]):
95+ alphas = []
96+ floors = []
97+ ceils = []
98+ index_order = [0 , 1 ] if indexing == "ij" else [ 1 , 0 ]
99+ unstacked_query_points = tf . unstack ( query_points , axis = 2 )
100+ # pylint: enable=bad-continuation
83101
84102 for dim in index_order :
85103 with tf .name_scope ("dim-" + str (dim )):
@@ -112,16 +130,21 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
112130 alpha = tf .expand_dims (alpha , 2 )
113131 alphas .append (alpha )
114132
115- tf .debugging .assert_less_equal (
116- tf .cast (batch_size * height * width , dtype = tf .dtypes .float32 ),
117- np .iinfo (np .int32 ).max / 8.0 ,
118- message = "The image size or batch size is sufficiently large "
119- "that the linearized addresses used by tf.gather "
120- "may exceed the int32 limit." )
121- flattened_grid = tf .reshape (grid ,
122- [batch_size * height * width , channels ])
123- batch_offsets = tf .reshape (
124- tf .range (batch_size ) * height * width , [batch_size , 1 ])
133+ # pylint: disable=bad-continuation
134+ with tf .control_dependencies ([
135+ tf .debugging .assert_less_equal (
136+ tf .cast (
137+ batch_size * height * width , dtype = tf .dtypes .float32 ),
138+ np .iinfo (np .int32 ).max / 8.0 ,
139+ message = "The image size or batch size is sufficiently "
140+ "large that the linearized addresses used by tf.gather "
141+ "may exceed the int32 limit." )
142+ ]):
143+ flattened_grid = tf .reshape (
144+ grid , [batch_size * height * width , channels ])
145+ batch_offsets = tf .reshape (
146+ tf .range (batch_size ) * height * width , [batch_size , 1 ])
147+ # pylint: enable=bad-continuation
125148
126149 # This wraps tf.gather. We reshape the image data such that the
127150 # batch, y, and x coordinates are pulled into the first dimension.
0 commit comments