Skip to content

Commit d921dfa

Browse files
WindQAQ authored and seanpmorgan committed
Migrate dense image warp (#53)
*ENH: migrate dense_image_warp.py
1 parent 7ef8360 commit d921dfa

File tree

5 files changed

+457
-0
lines changed

5 files changed

+457
-0
lines changed

tensorflow_addons/image/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ py_library(
66
name = "image",
77
srcs = ([
88
"__init__.py",
9+
"dense_image_warp.py",
910
"distort_image_ops.py",
1011
"transform_ops.py",
1112
]),
@@ -17,6 +18,19 @@ py_library(
1718
srcs_version = "PY2AND3",
1819
)
1920

21+
py_test(
22+
name = "dense_image_warp_test",
23+
size = "small",
24+
srcs = [
25+
"dense_image_warp_test.py",
26+
],
27+
main = "dense_image_warp_test.py",
28+
srcs_version = "PY2AND3",
29+
deps = [
30+
":image",
31+
],
32+
)
33+
2034
py_test(
2135
name = "distort_image_ops_test",
2236
size = "small",

tensorflow_addons/image/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
## Maintainers
44
| Submodule | Maintainers | Contact Info |
55
|:---------- |:----------- |:--------------|
6+
| dense_image_warp | | |
67
| distort_image_ops | | |
78
| transform_ops | | |
89

910
## Components
1011
| Submodule | Image Processing Function | Reference |
1112
|:---------- |:----------- |:----------- |
13+
| dense_image_warp | dense_image_warp | |
14+
| dense_image_warp | interpolate_bilinear | |
1215
| distort_image_ops | adjust_hsv_in_yiq | |
1316
| distort_image_ops | random_hsv_in_yiq | |
1417
| transform_ops | angles_to_projective_transforms | |

tensorflow_addons/image/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from tensorflow_addons.image.dense_image_warp import dense_image_warp
21+
from tensorflow_addons.image.dense_image_warp import interpolate_bilinear
2022
from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq
2123
from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq
2224
from tensorflow_addons.image.transform_ops import rotate
tensorflow_addons/image/dense_image_warp.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Image warping using per-pixel flow vectors."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import numpy as np
21+
import tensorflow as tf
22+
23+
24+
@tf.function
def interpolate_bilinear(grid,
                         query_points,
                         name="interpolate_bilinear",
                         indexing="ij"):
    """Similar to Matlab's interp2 function.

    Finds values for query points on a grid using bilinear interpolation.
    Query coordinates that fall outside the grid are clamped: the integer
    floor index is limited to `[0, size - 2]` and the interpolation weight
    to `[0, 1]`, so out-of-range queries take the value at the nearest grid
    border.

    Args:
      grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
      query_points: a 3-D float `Tensor` of N points with shape
        `[batch, N, 2]`.
      name: a name for the operation (optional).
      indexing: whether the query points are specified as row and column
        (ij), or Cartesian coordinates (xy).

    Returns:
      values: a 3-D `Tensor` with shape `[batch, N, channels]`

    Raises:
      ValueError: if the indexing mode is invalid, or if the shape of the
        inputs invalid.

    Additional `tf.debugging` assertions check that `query_points` is 3-D
    with a last dimension of size 2, that the grid height and width are at
    least 2, and that `batch * height * width` is small enough for int32
    linearized gather addresses.
    """
    if indexing != "ij" and indexing != "xy":
        raise ValueError("Indexing mode must be 'ij' or 'xy'")

    with tf.name_scope(name):
        grid = tf.convert_to_tensor(grid)
        query_points = tf.convert_to_tensor(query_points)
        shape = grid.get_shape().as_list()
        if len(shape) != 4:
            msg = "Grid must be 4 dimensional. Received size: "
            raise ValueError(msg + str(grid.get_shape()))

        # Use the dynamic shape from here on so that dimensions unknown at
        # trace time are still handled.
        batch_size, height, width, channels = (tf.shape(grid)[0],
                                               tf.shape(grid)[1],
                                               tf.shape(grid)[2],
                                               tf.shape(grid)[3])

        shape = [batch_size, height, width, channels]
        query_type = query_points.dtype
        grid_type = grid.dtype

        tf.debugging.assert_equal(
            len(query_points.get_shape()),
            3,
            message="Query points must be 3 dimensional.")
        tf.debugging.assert_equal(
            tf.shape(query_points)[2],
            2,
            message="Query points must be size 2 in dim 2.")

        num_queries = tf.shape(query_points)[1]

        tf.debugging.assert_greater_equal(
            height, 2, message="Grid height must be at least 2.")
        tf.debugging.assert_greater_equal(
            width, 2, message="Grid width must be at least 2.")

        alphas = []
        floors = []
        ceils = []
        # For "xy" indexing the first query coordinate is x (the column), so
        # visit the components in swapped order; after the loop, index 0 of
        # each list always refers to rows and index 1 to columns.
        index_order = [0, 1] if indexing == "ij" else [1, 0]
        unstacked_query_points = tf.unstack(query_points, axis=2)

        for dim in index_order:
            with tf.name_scope("dim-" + str(dim)):
                queries = unstacked_query_points[dim]

                size_in_indexing_dimension = shape[dim + 1]

                # max_floor is size_in_indexing_dimension - 2 so that
                # max_floor + 1 is still a valid index into the grid.
                max_floor = tf.cast(size_in_indexing_dimension - 2,
                                    query_type)
                min_floor = tf.constant(0.0, dtype=query_type)
                floor = tf.math.minimum(
                    tf.math.maximum(min_floor, tf.math.floor(queries)),
                    max_floor)
                int_floor = tf.cast(floor, tf.dtypes.int32)
                floors.append(int_floor)
                ceil = int_floor + 1
                ceils.append(ceil)

                # alpha has the same type as the grid, as we will directly
                # use alpha when taking linear combinations of pixel values
                # from the image.
                alpha = tf.cast(queries - floor, grid_type)
                min_alpha = tf.constant(0.0, dtype=grid_type)
                max_alpha = tf.constant(1.0, dtype=grid_type)
                alpha = tf.math.minimum(
                    tf.math.maximum(min_alpha, alpha), max_alpha)

                # Expand alpha to [b, n, 1] so we can use broadcasting
                # (since the alpha values don't depend on the channel).
                alpha = tf.expand_dims(alpha, 2)
                alphas.append(alpha)

        # Compute the element count in int64: the int32 product could wrap
        # around for exactly the oversized inputs this check must reject.
        num_grid_points = (tf.cast(batch_size, tf.dtypes.int64) *
                           tf.cast(height, tf.dtypes.int64) *
                           tf.cast(width, tf.dtypes.int64))
        tf.debugging.assert_less_equal(
            tf.cast(num_grid_points, dtype=tf.dtypes.float32),
            np.iinfo(np.int32).max / 8.0,
            message="The image size or batch size is sufficiently large "
            "that the linearized addresses used by tf.gather "
            "may exceed the int32 limit.")
        flattened_grid = tf.reshape(grid,
                                    [batch_size * height * width, channels])
        batch_offsets = tf.reshape(
            tf.range(batch_size) * height * width, [batch_size, 1])

        # This wraps tf.gather. We reshape the image data such that the
        # batch, y, and x coordinates are pulled into the first dimension.
        # Then we gather. Finally, we reshape the output back. It's possible
        # this code would be made simpler by using tf.gather_nd.
        def gather(y_coords, x_coords, name):
            with tf.name_scope("gather-" + name):
                linear_coordinates = (
                    batch_offsets + y_coords * width + x_coords)
                gathered_values = tf.gather(flattened_grid,
                                            linear_coordinates)
                return tf.reshape(gathered_values,
                                  [batch_size, num_queries, channels])

        # grab the pixel values in the 4 corners around each query point
        top_left = gather(floors[0], floors[1], "top_left")
        top_right = gather(floors[0], ceils[1], "top_right")
        bottom_left = gather(ceils[0], floors[1], "bottom_left")
        bottom_right = gather(ceils[0], ceils[1], "bottom_right")

        # now, do the actual interpolation: first along columns for the top
        # and bottom rows, then along rows between those two results.
        with tf.name_scope("interpolate"):
            interp_top = alphas[1] * (top_right - top_left) + top_left
            interp_bottom = alphas[1] * (
                bottom_right - bottom_left) + bottom_left
            interp = alphas[0] * (interp_bottom - interp_top) + interp_top

        return interp
158+
159+
160+
@tf.function
def dense_image_warp(image, flow, name="dense_image_warp"):
    """Image warping using per-pixel flow vectors.

    Apply a non-linear warp to the image, where the warp is specified by a
    dense flow field of offset vectors that define the correspondences of
    pixel values in the output image back to locations in the source image.
    Specifically, the pixel value at output[b, j, i, c] is
    image[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].

    The locations specified by this formula do not necessarily map to an int
    index. Therefore, the pixel value is obtained by bilinear
    interpolation of the 4 nearest pixels around
    (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
    of the image, we use the nearest pixel values at the image boundary.

    Args:
      image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
      flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
      name: A name for the operation (optional).

      Note that image and flow can be of type tf.half, tf.float32, or
      tf.float64, and do not necessarily have to be the same type.

    Returns:
      A 4-D float `Tensor` with shape `[batch, height, width, channels]`
        and same type as input image.

    Raises:
      ValueError: if height < 2 or width < 2 or the inputs have the wrong
        number of dimensions.
    """
    with tf.name_scope(name):
        # Accept any tensor-like input (e.g. nested lists), since
        # `flow.dtype` is read below.
        image = tf.convert_to_tensor(image)
        flow = tf.convert_to_tensor(flow)

        batch_size, height, width, channels = (tf.shape(image)[0],
                                               tf.shape(image)[1],
                                               tf.shape(image)[2],
                                               tf.shape(image)[3])

        # The flow is defined on the image grid. Turn the flow into a list of
        # query points in the grid space. The grid is stacked as (row, column)
        # to match the default "ij" indexing of interpolate_bilinear.
        grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
        stacked_grid = tf.cast(
            tf.stack([grid_y, grid_x], axis=2), flow.dtype)
        batched_grid = tf.expand_dims(stacked_grid, axis=0)
        query_points_on_grid = batched_grid - flow
        query_points_flattened = tf.reshape(query_points_on_grid,
                                            [batch_size, height * width, 2])
        # Compute values at the query points, then reshape the result back to
        # the image grid.
        interpolated = interpolate_bilinear(image, query_points_flattened)
        interpolated = tf.reshape(interpolated,
                                  [batch_size, height, width, channels])
        return interpolated

0 commit comments

Comments
 (0)