Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d5ffe39
migrate dense_image_warp.py
WindQAQ Feb 25, 2019
e26549a
add test for dense_image_warp
WindQAQ Feb 25, 2019
46d9a21
import dense_image_warp
WindQAQ Feb 25, 2019
c8cf220
modify BUILD file for dense_image_warp
WindQAQ Feb 25, 2019
8f0e233
correct tf imports
WindQAQ Feb 26, 2019
c1ff0bd
v2 compatible
WindQAQ Feb 26, 2019
1aa4f89
remove unused sessions
WindQAQ Feb 26, 2019
8d334ff
update README
WindQAQ Feb 26, 2019
2901fde
make non-test methods private
WindQAQ Feb 26, 2019
be733b2
remove unused method for placeholder
WindQAQ Feb 26, 2019
47fbb60
fix style
WindQAQ Feb 27, 2019
2d1daf9
remove comment
WindQAQ Feb 27, 2019
881e5f8
nit
WindQAQ Feb 27, 2019
603264b
consolidate libraries
WindQAQ Feb 28, 2019
f78e454
tf.function on _interpolate_bilinear
WindQAQ Feb 28, 2019
6728f80
merge master
WindQAQ Mar 2, 2019
71e38ec
merge master
WindQAQ Mar 5, 2019
b21fcee
code format
WindQAQ Mar 5, 2019
965c0d0
fix list order
WindQAQ Mar 5, 2019
548999b
change test size
WindQAQ Mar 8, 2019
d71ffce
fix wrond decorators
WindQAQ Mar 8, 2019
8978199
run test_size_exception in eager mode only
WindQAQ Mar 8, 2019
e228113
add TODO
WindQAQ Mar 8, 2019
26deb1b
use assertRaisesRegexp to catch exception
WindQAQ Mar 11, 2019
a807d55
Merge remote-tracking branch 'upstream/master' into migrate_dense_ima…
WindQAQ Mar 11, 2019
dfbf86d
remove tf_test_util
WindQAQ Mar 11, 2019
c8345e7
merge master
WindQAQ Mar 12, 2019
4a1216e
remove trivial blank lines
WindQAQ Mar 12, 2019
a85fc9f
merge master
WindQAQ Mar 27, 2019
4920ee8
update table of contents
WindQAQ Mar 27, 2019
2e716c2
make interpolate_bilinear public
WindQAQ Apr 1, 2019
75c2b8c
Merge branch 'master' of https://github.com/tensorflow/addons into mi…
WindQAQ Apr 2, 2019
4841648
add trailing comma
WindQAQ Apr 2, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_addons/image/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ py_library(
name = "image",
srcs = ([
"__init__.py",
"dense_image_warp.py",
"distort_image_ops.py",
"transform_ops.py",
]),
Expand All @@ -17,6 +18,19 @@ py_library(
srcs_version = "PY2AND3",
)

py_test(
name = "dense_image_warp_test",
size = "small",
srcs = [
"dense_image_warp_test.py",
],
main = "dense_image_warp_test.py",
srcs_version = "PY2AND3",
deps = [
":image",
],
)

py_test(
name = "distort_image_ops_test",
size = "small",
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_addons/image/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
## Maintainers
| Submodule | Maintainers | Contact Info |
|:---------- |:----------- |:--------------|
| dense_image_warp | | |
| distort_image_ops | | |
| transform_ops | | |

## Components
| Submodule | Image Processing Function | Reference |
|:---------- |:----------- |:----------- |
| dense_image_warp | dense_image_warp | |
| dense_image_warp | interpolate_bilinear | |
| distort_image_ops | adjust_hsv_in_yiq | |
| distort_image_ops | random_hsv_in_yiq | |
| transform_ops | angles_to_projective_transforms | |
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_addons/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function

from tensorflow_addons.image.dense_image_warp import dense_image_warp
from tensorflow_addons.image.dense_image_warp import interpolate_bilinear
from tensorflow_addons.image.distort_image_ops import adjust_hsv_in_yiq
from tensorflow_addons.image.distort_image_ops import random_hsv_in_yiq
from tensorflow_addons.image.transform_ops import rotate
Expand Down
211 changes: 211 additions & 0 deletions tensorflow_addons/image/dense_image_warp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image warping using per-pixel flow vectors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


@tf.function
def interpolate_bilinear(grid,
query_points,
name="interpolate_bilinear",
indexing="ij"):
"""Similar to Matlab's interp2 function.

Finds values for query points on a grid using bilinear interpolation.

Args:
grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
query_points: a 3-D float `Tensor` of N points with shape
`[batch, N, 2]`.
name: a name for the operation (optional).
indexing: whether the query points are specified as row and column (ij),
or Cartesian coordinates (xy).

Returns:
values: a 3-D `Tensor` with shape `[batch, N, channels]`

Raises:
ValueError: if the indexing mode is invalid, or if the shape of the
inputs invalid.
"""
if indexing != "ij" and indexing != "xy":
raise ValueError("Indexing mode must be \'ij\' or \'xy\'")

with tf.name_scope(name):
grid = tf.convert_to_tensor(grid)
query_points = tf.convert_to_tensor(query_points)
shape = grid.get_shape().as_list()
if len(shape) != 4:
msg = "Grid must be 4 dimensional. Received size: "
raise ValueError(msg + str(grid.get_shape()))

batch_size, height, width, channels = (tf.shape(grid)[0],
tf.shape(grid)[1],
tf.shape(grid)[2],
tf.shape(grid)[3])

shape = [batch_size, height, width, channels]
query_type = query_points.dtype
grid_type = grid.dtype

tf.debugging.assert_equal(
len(query_points.get_shape()),
3,
message="Query points must be 3 dimensional.")
tf.debugging.assert_equal(
tf.shape(query_points)[2],
2,
message="Query points must be size 2 in dim 2.")

num_queries = tf.shape(query_points)[1]

tf.debugging.assert_greater_equal(
height, 2, message="Grid height must be at least 2."),
tf.debugging.assert_greater_equal(
width, 2, message="Grid width must be at least 2.")

alphas = []
floors = []
ceils = []
index_order = [0, 1] if indexing == "ij" else [1, 0]
unstacked_query_points = tf.unstack(query_points, axis=2)

for dim in index_order:
with tf.name_scope("dim-" + str(dim)):
queries = unstacked_query_points[dim]

size_in_indexing_dimension = shape[dim + 1]

# max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
# is still a valid index into the grid.
max_floor = tf.cast(size_in_indexing_dimension - 2, query_type)
min_floor = tf.constant(0.0, dtype=query_type)
floor = tf.math.minimum(
tf.math.maximum(min_floor, tf.math.floor(queries)),
max_floor)
int_floor = tf.cast(floor, tf.dtypes.int32)
floors.append(int_floor)
ceil = int_floor + 1
ceils.append(ceil)

# alpha has the same type as the grid, as we will directly use alpha
# when taking linear combinations of pixel values from the image.
alpha = tf.cast(queries - floor, grid_type)
min_alpha = tf.constant(0.0, dtype=grid_type)
max_alpha = tf.constant(1.0, dtype=grid_type)
alpha = tf.math.minimum(
tf.math.maximum(min_alpha, alpha), max_alpha)

# Expand alpha to [b, n, 1] so we can use broadcasting
# (since the alpha values don't depend on the channel).
alpha = tf.expand_dims(alpha, 2)
alphas.append(alpha)

tf.debugging.assert_less_equal(
tf.cast(batch_size * height * width, dtype=tf.dtypes.float32),
np.iinfo(np.int32).max / 8.0,
message="The image size or batch size is sufficiently large "
"that the linearized addresses used by tf.gather "
"may exceed the int32 limit.")
flattened_grid = tf.reshape(grid,
[batch_size * height * width, channels])
batch_offsets = tf.reshape(
tf.range(batch_size) * height * width, [batch_size, 1])

# This wraps tf.gather. We reshape the image data such that the
# batch, y, and x coordinates are pulled into the first dimension.
# Then we gather. Finally, we reshape the output back. It's possible this
# code would be made simpler by using tf.gather_nd.
def gather(y_coords, x_coords, name):
with tf.name_scope("gather-" + name):
linear_coordinates = (
batch_offsets + y_coords * width + x_coords)
gathered_values = tf.gather(flattened_grid, linear_coordinates)
return tf.reshape(gathered_values,
[batch_size, num_queries, channels])

# grab the pixel values in the 4 corners around each query point
top_left = gather(floors[0], floors[1], "top_left")
top_right = gather(floors[0], ceils[1], "top_right")
bottom_left = gather(ceils[0], floors[1], "bottom_left")
bottom_right = gather(ceils[0], ceils[1], "bottom_right")

# now, do the actual interpolation
with tf.name_scope("interpolate"):
interp_top = alphas[1] * (top_right - top_left) + top_left
interp_bottom = alphas[1] * (
bottom_right - bottom_left) + bottom_left
interp = alphas[0] * (interp_bottom - interp_top) + interp_top

return interp


@tf.function
def dense_image_warp(image, flow, name="dense_image_warp"):
"""Image warping using per-pixel flow vectors.

Apply a non-linear warp to the image, where the warp is specified by a
dense flow field of offset vectors that define the correspondences of
pixel values in the output image back to locations in the source image.
Specifically, the pixel value at output[b, j, i, c] is
images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c].

The locations specified by this formula do not necessarily map to an int
index. Therefore, the pixel value is obtained by bilinear
interpolation of the 4 nearest pixels around
(b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside
of the image, we use the nearest pixel values at the image boundary.

Args:
image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
name: A name for the operation (optional).

Note that image and flow can be of type tf.half, tf.float32, or
tf.float64, and do not necessarily have to be the same type.

Returns:
A 4-D float `Tensor` with shape`[batch, height, width, channels]`
and same type as input image.

Raises:
ValueError: if height < 2 or width < 2 or the inputs have the wrong
number of dimensions.
"""
with tf.name_scope(name):
batch_size, height, width, channels = (tf.shape(image)[0],
tf.shape(image)[1],
tf.shape(image)[2],
tf.shape(image)[3])

# The flow is defined on the image grid. Turn the flow into a list of query
# points in the grid space.
grid_x, grid_y = tf.meshgrid(tf.range(width), tf.range(height))
stacked_grid = tf.cast(tf.stack([grid_y, grid_x], axis=2), flow.dtype)
batched_grid = tf.expand_dims(stacked_grid, axis=0)
query_points_on_grid = batched_grid - flow
query_points_flattened = tf.reshape(query_points_on_grid,
[batch_size, height * width, 2])
# Compute values at the query points, then reshape the result back to the
# image grid.
interpolated = interpolate_bilinear(image, query_points_flattened)
interpolated = tf.reshape(interpolated,
[batch_size, height, width, channels])
return interpolated
Loading