@@ -154,16 +154,25 @@ def _check_interpolation_correctness(self,
154154 shape ,
155155 image_type ,
156156 flow_type ,
157+ call_with_unknown_shapes = False ,
157158 num_probes = 5 ):
158159 """Interpolate, and then assert correctness for a few query
159160 locations."""
160161 low_precision = image_type == "float16" or flow_type == "float16"
161162 rand_image , rand_flows = self ._get_random_image_and_flows (
162163 shape , image_type , flow_type )
163164
164- interp = dense_image_warp (
165- image = tf .convert_to_tensor (rand_image ),
166- flow = tf .convert_to_tensor (rand_flows ))
165+ if call_with_unknown_shapes :
166+ fn = dense_image_warp .get_concrete_function (
167+ tf .TensorSpec (shape = None , dtype = image_type ),
168+ tf .TensorSpec (shape = None , dtype = flow_type ))
169+ interp = fn (
170+ image = tf .convert_to_tensor (rand_image ),
171+ flow = tf .convert_to_tensor (rand_flows ))
172+ else :
173+ interp = dense_image_warp (
174+ image = tf .convert_to_tensor (rand_image ),
175+ flow = tf .convert_to_tensor (rand_flows ))
167176
168177 for _ in range (num_probes ):
169178 batch_index = np .random .randint (0 , shape [0 ])
@@ -189,6 +198,14 @@ def test_interpolation(self):
189198 self ._check_interpolation_correctness (
190199 shape , im_type , flow_type )
191200
201+ def test_unknown_shapes (self ):
202+ """Apply _check_interpolation_correctness() for a few sizes and check
203+ for tf.Dataset compatibility."""
204+ shapes_to_try = [[3 , 4 , 5 , 6 ], [1 , 5 , 5 , 3 ], [1 , 2 , 2 , 1 ]]
205+ for shape in shapes_to_try :
206+ self ._check_interpolation_correctness (shape , "float32" , "float32" ,
207+ True )
208+
192209 def test_gradients_exist (self ):
193210 """Check that backprop can run.
194211
0 commit comments