Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_addons/image/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ py_library(
"distort_image_ops.py",
"filters.py",
"transform_ops.py",
"utils.py",
]),
data = [
"//tensorflow_addons/custom_ops/image:_distort_image_ops.so",
Expand Down Expand Up @@ -85,3 +86,16 @@ py_test(
":image",
],
)

# Test target for utils_test.py; it depends on the ":image" library target.
py_test(
    name = "utils_test",
    size = "small",
    srcs = [
        "utils_test.py",
    ],
    main = "utils_test.py",
    srcs_version = "PY2AND3",
    deps = [
        ":image",
    ],
)
47 changes: 9 additions & 38 deletions tensorflow_addons/image/transform_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.image import utils as img_utils
from tensorflow_addons.utils.resource_loader import get_path_to_datafile

_image_ops_so = tf.load_op_library(
Expand All @@ -40,8 +41,7 @@ def transform(images,
Args:
images: A tensor of shape (num_images, num_rows, num_columns,
num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
(num_rows, num_columns) (HW). The rank must be statically known (the
shape is not `TensorShape(None)`.
(num_rows, num_columns) (HW).
transforms: Projective transform matrix/matrices. A vector of length 8 or
tensor of size N x 8. If one row of transforms is
[a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
Expand Down Expand Up @@ -72,16 +72,8 @@ def transform(images,
transforms, name="transforms", dtype=tf.dtypes.float32)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
elif image_or_images.get_shape().ndims is None:
raise TypeError("image_or_images rank must be statically known")
elif len(image_or_images.get_shape()) == 2:
images = image_or_images[None, :, :, None]
elif len(image_or_images.get_shape()) == 3:
images = image_or_images[None, :, :, :]
elif len(image_or_images.get_shape()) == 4:
images = image_or_images
else:
raise TypeError("Images should have rank between 2 and 4.")
images = img_utils.to_4D_image(image_or_images)
original_ndims = img_utils.get_ndims(image_or_images)

if output_shape is None:
output_shape = tf.shape(images)[1:3]
Expand Down Expand Up @@ -109,12 +101,7 @@ def transform(images,
output_shape=output_shape,
transforms=transforms,
interpolation=interpolation.upper())
if len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
return output[0, :, :, :]
else:
return output
return img_utils.from_4D_image(output, original_ndims)


@tf.function
Expand Down Expand Up @@ -299,8 +286,7 @@ def rotate(images, angles, interpolation="NEAREST", name=None):
images: A tensor of shape
(num_images, num_rows, num_columns, num_channels)
(NHWC), (num_rows, num_columns, num_channels) (HWC), or
(num_rows, num_columns) (HW). The rank must be statically known (the
shape is not `TensorShape(None)`.
(num_rows, num_columns) (HW).
angles: A scalar angle to rotate all images by, or (if images has rank 4)
a vector of length num_images, with an angle for each image in the
batch.
Expand All @@ -319,28 +305,13 @@ def rotate(images, angles, interpolation="NEAREST", name=None):
image_or_images = tf.convert_to_tensor(images)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
raise TypeError("Invalid dtype %s." % image_or_images.dtype)
if image_or_images.get_shape().ndims is None:
raise TypeError("image_or_images rank must be statically known")
elif len(image_or_images.get_shape()) == 2:
images = image_or_images[None, :, :, None]
elif len(image_or_images.get_shape()) == 3:
images = image_or_images[None, :, :, :]
elif len(image_or_images.get_shape()) == 4:
images = image_or_images
else:
raise TypeError("Images should have rank between 2 and 4.")
images = img_utils.to_4D_image(image_or_images)
original_ndims = img_utils.get_ndims(image_or_images)

image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None]
image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None]
output = transform(
images,
angles_to_projective_transforms(angles, image_height, image_width),
interpolation=interpolation)
if image_or_images.get_shape().ndims is None:
raise TypeError("image_or_images rank must be statically known")
elif len(image_or_images.get_shape()) == 2:
return output[0, :, :, 0]
elif len(image_or_images.get_shape()) == 3:
return output[0, :, :, :]
else:
return output
return img_utils.from_4D_image(output, original_ndims)
57 changes: 57 additions & 0 deletions tensorflow_addons/image/transform_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function

import os
import tempfile
import numpy as np
import tensorflow as tf

Expand Down Expand Up @@ -72,6 +74,15 @@ def test_transform_static_output_shape(self):
output_shape=tf.constant([3, 5]))
self.assertAllEqual([3, 5], result.shape)

@test_utils.run_in_graph_and_eager_modes
def test_transform_unknown_shape(self):
    """`transform` traced with a fully unknown input shape keeps its input.

    The concrete function is built once from a `TensorSpec` with
    `shape=None` and then fed rank 2, 3 and 4 images of ones; each output
    must equal the corresponding input.
    """
    unknown_spec = tf.TensorSpec(shape=None, dtype=tf.float32)
    traced = transform_ops.transform.get_concrete_function(
        unknown_spec, [1, 0, 0, 0, 1, 0, 0, 0])
    for dims in ((2, 4), (2, 4, 3), (1, 2, 4, 3)):
        ones = tf.ones(shape=dims)
        expected = self.evaluate(ones)
        self.assertAllEqual(expected, self.evaluate(traced(ones)))

def _test_grad(self, input_shape, output_shape=None):
image_size = tf.math.cumprod(input_shape)[-1]
image_size = tf.cast(image_size, tf.float32)
Expand Down Expand Up @@ -270,6 +281,52 @@ def test_rotate_static_shape(self):
image, tf.random.uniform((), -1, 1), interpolation="BILINEAR")
self.assertEqual(image.get_shape(), result.get_shape())

def test_unknown_shape(self):
    """`rotate` by angle 0, traced with an unknown input shape, is a no-op.

    The concrete function is built once from a `TensorSpec` with
    `shape=None` and then fed rank 2, 3 and 4 images of ones.
    """
    unknown_spec = tf.TensorSpec(shape=None, dtype=tf.float32)
    traced = transform_ops.rotate.get_concrete_function(unknown_spec, 0)
    for dims in ((2, 4), (2, 4, 3), (1, 2, 4, 3)):
        ones = tf.ones(shape=dims)
        expected = self.evaluate(ones)
        self.assertAllEqual(expected, self.evaluate(traced(ones)))


def preprocess_image(image):
    """Decode an encoded image to 3 channels and rotate it.

    Args:
      image: Encoded image contents, as accepted by `tf.image.decode_image`.

    Returns:
      The decoded image after `transform_ops.rotate` with angle 1.57
      (presumably radians, i.e. roughly a quarter turn -- confirm against
      `rotate`). The result's type and static shape are printed as a
      debugging aid.
    """
    decoded = tf.image.decode_image(image, channels=3)
    rotated = transform_ops.rotate(decoded, 1.57)
    print(type(rotated))
    print(rotated.get_shape())
    return rotated


def load_and_preprocess_tfa(path):
    """Read the file at `path` and pass its contents to `preprocess_image`.

    Args:
      path: Path of the encoded image file, as accepted by `tf.io.read_file`.

    Returns:
      Whatever `preprocess_image` returns for the file contents.
    """
    return preprocess_image(tf.io.read_file(path))


class DatasetTests(tf.test.TestCase):
    """Runs `rotate` inside a `tf.data` pipeline fed from a JPEG on disk."""

    def setUp(self):
        """Write a random 24x24 RGB JPEG to a temporary file."""
        # Let tf.test.TestCase perform its own per-test setup first.
        super(DatasetTests, self).setUp()

        # encode_jpeg is documented for uint8 input, so draw integer pixel
        # values in [0, 255] rather than floats in [0, 1).
        pixels = np.random.randint(0, 256, size=(24, 24, 3), dtype=np.uint8)
        encoded = tf.image.encode_jpeg(
            pixels, format='rgb', quality=100).numpy()

        # delete=False keeps the file on disk once the handle is closed, so
        # the dataset can read it by path; tearDown removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as fh:
            fh.write(encoded)
            self.image_path = fh.name

    def tearDown(self):
        """Remove the temporary JPEG created by `setUp`."""
        os.unlink(self.image_path)
        super(DatasetTests, self).tearDown()

    @test_utils.run_v2_only
    def test_rotate_static_shape(self):
        """The mapped dataset yields one image with the source's shape."""
        path_ds = tf.data.Dataset.from_tensor_slices([self.image_path])
        image_ds = path_ds.map(load_and_preprocess_tfa)

        num_images = 0
        for image in image_ds:
            # The fixture is 24x24 and is decoded with channels=3; rotating
            # is expected to keep that shape.
            self.assertAllEqual([24, 24, 3], image.shape)
            num_images += 1
        # Exactly one path was put into the dataset.
        self.assertEqual(1, num_images)


# Run this module's test cases when the file is executed as a script.
if __name__ == "__main__":
    tf.test.main()
98 changes: 98 additions & 0 deletions tensorflow_addons/image/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image util ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def get_ndims(image):
    """Return the rank of `image`, using the static rank when it is known.

    Args:
      image: A tensor, or any object exposing `get_shape()` whose result
        has an `ndims` attribute.

    Returns:
      The static rank as reported by `image.get_shape().ndims` when it is
      not `None` (including 0 for a scalar); otherwise the result of
      `tf.rank(image)`.
    """
    static_rank = image.get_shape().ndims
    # Compare against None explicitly: a static rank of 0 is falsy, so
    # `ndims or tf.rank(image)` would wrongly discard a known rank of 0.
    if static_rank is not None:
        return static_rank
    return tf.rank(image)


@tf.function
def to_4D_image(image):
    """Expand a rank 2, 3 or 4 image to rank 4.

    A rank-2 image gains a leading and a trailing axis, a rank-3 image gains
    a leading axis, and a rank-4 image is returned as is. When the rank is
    not statically known the work is delegated to `_dynamic_to_4D_image`.

    Args:
      image: 2/3/4D tensor.

    Returns:
      4D tensor with the same type.
    """
    tf.debugging.assert_rank_in(image, [2, 3, 4])
    static_rank = image.get_shape().ndims
    if static_rank is None:
        # Rank is only known at run time.
        return _dynamic_to_4D_image(image)
    if static_rank == 2:
        return image[None, :, :, None]
    if static_rank == 3:
        return image[None, :, :, :]
    return image


def _dynamic_to_4D_image(image):
    """Reshape `image` to rank 4 when its rank is only known at run time.

    The target shape is the dynamic shape of `image` with size-1 axes added
    around it:
      rank 4: [N, H, W, C] -> [N, H, W, C]
      rank 3: [H, W, C]    -> [1, H, W, C]
      rank 2: [H, W]       -> [1, H, W, 1]
    """
    rank = tf.rank(image)
    dynamic_shape = tf.shape(image)
    # 1 if a leading axis must be added (rank <= 3), else 0.
    num_leading = tf.cast(tf.less_equal(rank, 3), dtype=tf.int32)
    # 1 if a trailing axis must be added (rank == 2), else 0.
    num_trailing = tf.cast(tf.equal(rank, 2), dtype=tf.int32)
    # Vectors of ones of length 0 or 1, concatenated around the shape.
    leading_ones = tf.ones(shape=num_leading, dtype=tf.int32)
    trailing_ones = tf.ones(shape=num_trailing, dtype=tf.int32)
    target_shape = tf.concat(
        [leading_ones, dynamic_shape, trailing_ones], axis=0)
    return tf.reshape(image, target_shape)


@tf.function
def from_4D_image(image, ndims):
    """Undo `to_4D_image`: bring a 4D image back to rank `ndims`.

    Args:
      image: 4D tensor.
      ndims: The original rank of the image, either a Python value or a
        `tf.Tensor`. A tensor is handled by `_dynamic_from_4D_image`.

    Returns:
      `ndims`-D tensor with the same type. For a non-tensor `ndims` other
      than 2 or 3, `image` is returned unchanged.
    """
    tf.debugging.assert_rank(image, 4)
    if isinstance(ndims, tf.Tensor):
        # The original rank is only known at run time.
        return _dynamic_from_4D_image(image, ndims)
    if ndims == 2:
        # Drop both the leading and the trailing axis.
        return tf.squeeze(image, [0, 3])
    if ndims == 3:
        # Drop only the leading axis.
        return tf.squeeze(image, [0])
    return image


def _dynamic_from_4D_image(image, original_rank):
    """Reshape a 4D `image` back to `original_rank` given as a tensor.

    The target shape is a slice of the 4D shape:
      rank 4: [N, H, W, C] -> [N, H, W, C]
      rank 3: [1, H, W, C] -> [H, W, C]
      rank 2: [1, H, W, 1] -> [H, W]
    """
    full_shape = tf.shape(image)
    # 1 if the leading axis is dropped (rank <= 3), else 0.
    drop_leading = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
    # 1 if the trailing axis is dropped (rank == 2), else 0.
    drop_trailing = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
    start = drop_leading
    stop = 4 - drop_trailing
    target_shape = full_shape[start:stop]
    return tf.reshape(image, target_shape)
97 changes: 97 additions & 0 deletions tensorflow_addons/image/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for util ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_addons.image import utils as img_utils
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class UtilsOpsTest(tf.test.TestCase):
    """Tests for `to_4D_image` and `from_4D_image` in the image utils."""

    def test_to_4D_image(self):
        """Statically shaped rank 2/3/4 inputs all become (1, 2, 4, 1)."""
        for dims in ((2, 4), (2, 4, 1), (1, 2, 4, 1)):
            expected = tf.ones(shape=(1, 2, 4, 1))
            actual = img_utils.to_4D_image(tf.ones(shape=dims))
            # The static shape must match as well as the values.
            self.assertAllEqual(expected.get_shape(), actual.get_shape())
            self.assertAllEqual(
                self.evaluate(expected), self.evaluate(actual))

    def test_to_4D_image_with_unknown_shape(self):
        """A function traced with an unknown shape handles rank 2/3/4."""
        unknown_spec = tf.TensorSpec(shape=None, dtype=tf.float32)
        traced = img_utils.to_4D_image.get_concrete_function(unknown_spec)
        for dims in ((2, 4), (2, 4, 1), (1, 2, 4, 1)):
            expected = tf.ones(shape=(1, 2, 4, 1))
            actual = traced(tf.ones(shape=dims))
            self.assertAllEqual(
                self.evaluate(expected), self.evaluate(actual))

    def test_to_4D_image_with_invalid_shape(self):
        """Rank 1 and rank 5 inputs are rejected."""
        rank_errors = (ValueError, tf.errors.InvalidArgumentError)

        with self.assertRaises(rank_errors):
            img_utils.to_4D_image(tf.ones(shape=(1,)))

        with self.assertRaises(rank_errors):
            img_utils.to_4D_image(tf.ones(shape=(1, 2, 4, 3, 2)))

    def test_from_4D_image(self):
        """A (1, 2, 4, 1) image is restored to each original static shape."""
        for dims in ((2, 4), (2, 4, 1), (1, 2, 4, 1)):
            expected = tf.ones(shape=dims)
            actual = img_utils.from_4D_image(
                tf.ones(shape=(1, 2, 4, 1)), len(dims))
            # The static shape must match as well as the values.
            self.assertAllEqual(expected.get_shape(), actual.get_shape())
            self.assertAllEqual(
                self.evaluate(expected), self.evaluate(actual))

    def test_from_4D_image_with_unknown_shape(self):
        """Restoring also works with an unknown shape and a tensor rank."""
        for dims in ((2, 4), (2, 4, 1), (1, 2, 4, 1)):
            expected = tf.ones(shape=dims)
            unknown_spec = tf.TensorSpec(shape=None, dtype=tf.float32)
            traced = img_utils.from_4D_image.get_concrete_function(
                unknown_spec, tf.size(dims))
            actual = traced(tf.ones(shape=(1, 2, 4, 1)), tf.size(dims))
            self.assertAllEqual(
                self.evaluate(expected), self.evaluate(actual))

    def test_from_4D_image_with_invalid_data(self):
        """A leading axis of size 2 cannot be dropped to reach rank 2."""
        with self.assertRaises(ValueError):
            squeezed = img_utils.from_4D_image(
                tf.ones(shape=(2, 2, 4, 1)), 2)
            self.evaluate(squeezed)

        with self.assertRaises(tf.errors.InvalidArgumentError):
            reshaped = img_utils.from_4D_image(
                tf.ones(shape=(2, 2, 4, 1)), tf.constant(2))
            self.evaluate(reshaped)

    def test_from_4D_image_with_invalid_shape(self):
        """Inputs that are not rank 4 are rejected for int and tensor rank."""
        rank_errors = (ValueError, tf.errors.InvalidArgumentError)
        for rank in (2, tf.constant(2)):
            with self.subTest(rank=rank):
                with self.assertRaises(rank_errors):
                    img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank)

                with self.assertRaises(rank_errors):
                    img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank)

                with self.assertRaises(rank_errors):
                    img_utils.from_4D_image(
                        tf.ones(shape=(1, 2, 4, 1, 1)), rank)


# Run this module's test cases when the file is executed as a script.
if __name__ == "__main__":
    tf.test.main()
Loading