2121import tensorflow as tf
2222
2323
24- @tf .function
2524def interpolate_bilinear (grid , query_points , indexing = "ij" , name = None ):
2625 """Similar to Matlab's interp2 function.
2726
@@ -48,31 +47,33 @@ def interpolate_bilinear(grid, query_points, indexing="ij", name=None):
4847 with tf .name_scope (name or "interpolate_bilinear" ):
4948 grid = tf .convert_to_tensor (grid )
5049 query_points = tf .convert_to_tensor (query_points )
51- shape = grid . get_shape (). as_list ()
52- if len (shape ) != 4 :
50+
51+ if len (grid . shape ) != 4 :
5352 msg = "Grid must be 4 dimensional. Received size: "
54- raise ValueError (msg + str (grid .get_shape ()))
53+ raise ValueError (msg + str (grid .shape ))
54+
55+ if len (query_points .shape ) != 3 :
56+ raise ValueError ("Query points must be 3 dimensional." )
57+
58+ grid_shape = tf .shape (grid )
59+ query_shape = tf .shape (query_points )
5560
56- batch_size , height , width , channels = (tf . shape ( grid ) [0 ],
57- tf . shape ( grid ) [1 ],
58- tf . shape ( grid ) [2 ],
59- tf . shape ( grid ) [3 ])
61+ batch_size , height , width , channels = (grid_shape [0 ],
62+ grid_shape [1 ],
63+ grid_shape [2 ],
64+ grid_shape [3 ])
6065
6166 shape = [batch_size , height , width , channels ]
67+ num_queries = query_shape [1 ]
68+
6269 query_type = query_points .dtype
6370 grid_type = grid .dtype
6471
6572 tf .debugging .assert_equal (
66- len (query_points .get_shape ()),
67- 3 ,
68- message = "Query points must be 3 dimensional." )
69- tf .debugging .assert_equal (
70- tf .shape (query_points )[2 ],
73+ query_shape [2 ],
7174 2 ,
7275 message = "Query points must be size 2 in dim 2." )
7376
74- num_queries = tf .shape (query_points )[1 ]
75-
7677 tf .debugging .assert_greater_equal (
7778 height , 2 , message = "Grid height must be at least 2." ),
7879 tf .debugging .assert_greater_equal (
0 commit comments