diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD
index 2280ce10bc..52b68cf8f6 100644
--- a/tensorflow_addons/image/BUILD
+++ b/tensorflow_addons/image/BUILD
@@ -12,6 +12,7 @@ py_library(
         "filters.py",
         "transform_ops.py",
         "translate_ops.py",
+        "utils.py",
     ]),
     data = [
         "//tensorflow_addons/custom_ops/image:_distort_image_ops.so",
@@ -99,3 +100,16 @@ py_test(
         ":image",
     ],
 )
+
+py_test(
+    name = "utils_test",
+    size = "small",
+    srcs = [
+        "utils_test.py",
+    ],
+    main = "utils_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":image",
+    ],
+)
diff --git a/tensorflow_addons/image/transform_ops.py b/tensorflow_addons/image/transform_ops.py
index 55df62ab17..a47dd0bc10 100644
--- a/tensorflow_addons/image/transform_ops.py
+++ b/tensorflow_addons/image/transform_ops.py
@@ -18,6 +18,7 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from tensorflow_addons.image import utils as img_utils
 from tensorflow_addons.utils.resource_loader import get_path_to_datafile
 
 _image_ops_so = tf.load_op_library(
@@ -40,8 +41,7 @@ def transform(images,
     Args:
       images: A tensor of shape (num_images, num_rows, num_columns,
         num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
-        (num_rows, num_columns) (HW). The rank must be statically known (the
-        shape is not `TensorShape(None)`.
+        (num_rows, num_columns) (HW).
       transforms: Projective transform matrix/matrices. A vector of length 8 or
         tensor of size N x 8. If one row of transforms is
         [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
@@ -72,16 +72,8 @@ def transform(images,
             transforms, name="transforms", dtype=tf.dtypes.float32)
         if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
             raise TypeError("Invalid dtype %s." % image_or_images.dtype)
-        elif image_or_images.get_shape().ndims is None:
-            raise TypeError("image_or_images rank must be statically known")
-        elif len(image_or_images.get_shape()) == 2:
-            images = image_or_images[None, :, :, None]
-        elif len(image_or_images.get_shape()) == 3:
-            images = image_or_images[None, :, :, :]
-        elif len(image_or_images.get_shape()) == 4:
-            images = image_or_images
-        else:
-            raise TypeError("Images should have rank between 2 and 4.")
+        images = img_utils.to_4D_image(image_or_images)
+        original_ndims = img_utils.get_ndims(image_or_images)
 
         if output_shape is None:
             output_shape = tf.shape(images)[1:3]
@@ -109,12 +101,7 @@ def transform(images,
             output_shape=output_shape,
             transforms=transforms,
             interpolation=interpolation.upper())
-        if len(image_or_images.get_shape()) == 2:
-            return output[0, :, :, 0]
-        elif len(image_or_images.get_shape()) == 3:
-            return output[0, :, :, :]
-        else:
-            return output
+        return img_utils.from_4D_image(output, original_ndims)
 
 
 @tf.function
@@ -299,8 +286,7 @@ def rotate(images, angles, interpolation="NEAREST", name=None):
       images: A tensor of shape
         (num_images, num_rows, num_columns, num_channels)
         (NHWC), (num_rows, num_columns, num_channels) (HWC), or
-        (num_rows, num_columns) (HW). The rank must be statically known (the
-        shape is not `TensorShape(None)`.
+        (num_rows, num_columns) (HW).
       angles: A scalar angle to rotate all images by, or (if images has rank 4)
         a vector of length num_images, with an angle for each image in the
         batch.
@@ -319,16 +305,8 @@ def rotate(images, angles, interpolation="NEAREST", name=None):
         image_or_images = tf.convert_to_tensor(images)
         if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
             raise TypeError("Invalid dtype %s." % image_or_images.dtype)
-        if image_or_images.get_shape().ndims is None:
-            raise TypeError("image_or_images rank must be statically known")
-        elif len(image_or_images.get_shape()) == 2:
-            images = image_or_images[None, :, :, None]
-        elif len(image_or_images.get_shape()) == 3:
-            images = image_or_images[None, :, :, :]
-        elif len(image_or_images.get_shape()) == 4:
-            images = image_or_images
-        else:
-            raise TypeError("Images should have rank between 2 and 4.")
+        images = img_utils.to_4D_image(image_or_images)
+        original_ndims = img_utils.get_ndims(image_or_images)
 
         image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None]
         image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None]
@@ -336,11 +314,4 @@ def rotate(images, angles, interpolation="NEAREST", name=None):
             images,
             angles_to_projective_transforms(angles, image_height, image_width),
             interpolation=interpolation)
-        if image_or_images.get_shape().ndims is None:
-            raise TypeError("image_or_images rank must be statically known")
-        elif len(image_or_images.get_shape()) == 2:
-            return output[0, :, :, 0]
-        elif len(image_or_images.get_shape()) == 3:
-            return output[0, :, :, :]
-        else:
-            return output
+        return img_utils.from_4D_image(output, original_ndims)
diff --git a/tensorflow_addons/image/transform_ops_test.py b/tensorflow_addons/image/transform_ops_test.py
index 5bd8b0d341..deb0c24af3 100644
--- a/tensorflow_addons/image/transform_ops_test.py
+++ b/tensorflow_addons/image/transform_ops_test.py
@@ -72,6 +72,15 @@ def test_transform_static_output_shape(self):
             output_shape=tf.constant([3, 5]))
         self.assertAllEqual([3, 5], result.shape)
 
+    @test_utils.run_in_graph_and_eager_modes
+    def test_transform_unknown_shape(self):
+        fn = transform_ops.transform.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32),
+            [1, 0, 0, 0, 1, 0, 0, 0])
+        for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3):
+            image = tf.ones(shape=shape)
+            self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image)))
+
     def _test_grad(self, input_shape, output_shape=None):
         image_size = tf.math.cumprod(input_shape)[-1]
         image_size = tf.cast(image_size, tf.float32)
@@ -270,6 +279,13 @@ def test_rotate_static_shape(self):
             image, tf.random.uniform((), -1, 1), interpolation="BILINEAR")
         self.assertEqual(image.get_shape(), result.get_shape())
 
+    def test_unknown_shape(self):
+        fn = transform_ops.rotate.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32), 0)
+        for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3):
+            image = tf.ones(shape=shape)
+            self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image)))
+
 
 if __name__ == "__main__":
     tf.test.main()
diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py
new file mode 100644
index 0000000000..036a4caaa1
--- /dev/null
+++ b/tensorflow_addons/image/utils.py
@@ -0,0 +1,123 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Image util ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+def get_ndims(image):
+    """Return the rank of `image`.
+
+    Args:
+      image: A tensor.
+
+    Returns:
+      The static rank as a Python int when it is known, otherwise a scalar
+      tensor holding the dynamic rank.
+    """
+    ndims = image.get_shape().ndims
+    # Compare against `None` explicitly: a statically known rank of 0 is
+    # falsy and must not be mistaken for an unknown rank.
+    return tf.rank(image) if ndims is None else ndims
+
+
+@tf.function
+def to_4D_image(image):
+    """Convert 2/3/4D image to 4D image.
+
+    Args:
+      image: 2/3/4D tensor.
+
+    Returns:
+      4D tensor with the same type.
+    """
+    # yapf: disable
+    with tf.control_dependencies([
+        tf.debugging.assert_rank_in(
+            image, [2, 3, 4], message='`image` must be 2/3/4D tensor')
+    ]):
+        # yapf: enable
+        ndims = image.get_shape().ndims
+        if ndims is None:
+            return _dynamic_to_4D_image(image)
+        elif ndims == 2:
+            return image[None, :, :, None]
+        elif ndims == 3:
+            return image[None, :, :, :]
+        else:
+            return image
+
+
+def _dynamic_to_4D_image(image):
+    """Reshape `image` to 4D when its rank is only known at run time."""
+    shape = tf.shape(image)
+    original_rank = tf.rank(image)
+    # 4D image => [N, H, W, C] or [N, C, H, W]
+    # 3D image => [1, H, W, C] or [1, C, H, W]
+    # 2D image => [1, H, W, 1]
+    left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
+    right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
+    # yapf: disable
+    new_shape = tf.concat(
+        [tf.ones(shape=left_pad, dtype=tf.int32),
+         shape,
+         tf.ones(shape=right_pad, dtype=tf.int32)],
+        axis=0)
+    # yapf: enable
+    return tf.reshape(image, new_shape)
+
+
+@tf.function
+def from_4D_image(image, ndims):
+    """Convert back to an image with `ndims` rank.
+
+    Args:
+      image: 4D tensor.
+      ndims: The original rank of the image, either a Python integer or a
+        scalar integer tensor (see `get_ndims`).
+
+    Returns:
+      `ndims`-D tensor with the same type.
+    """
+    # yapf: disable
+    with tf.control_dependencies([
+        tf.debugging.assert_rank(
+            image, 4, message='`image` must be 4D tensor')
+    ]):
+        # yapf: enable
+        if isinstance(ndims, tf.Tensor):
+            return _dynamic_from_4D_image(image, ndims)
+        elif ndims == 2:
+            return tf.squeeze(image, [0, 3])
+        elif ndims == 3:
+            return tf.squeeze(image, [0])
+        else:
+            return image
+
+
+def _dynamic_from_4D_image(image, original_rank):
+    """Reshape 4D `image` back to a rank only known at run time."""
+    shape = tf.shape(image)
+    # 4D image <= [N, H, W, C] or [N, C, H, W]
+    # 3D image <= [1, H, W, C] or [1, C, H, W]
+    # 2D image <= [1, H, W, 1]
+    begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
+    end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
+    new_shape = shape[begin:end]
+    return tf.reshape(image, new_shape)
diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py
new file mode 100644
index 0000000000..1e7ddb7d13
--- /dev/null
+++ b/tensorflow_addons/image/utils_test.py
@@ -0,0 +1,99 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for util ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow_addons.image import utils as img_utils
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class UtilsOpsTest(tf.test.TestCase):
+    def test_to_4D_image(self):
+        for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
+            exp = tf.ones(shape=(1, 2, 4, 1))
+            res = img_utils.to_4D_image(tf.ones(shape=shape))
+            # static shape:
+            self.assertAllEqual(exp.get_shape(), res.get_shape())
+            self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
+
+    def test_to_4D_image_with_unknown_shape(self):
+        fn = img_utils.to_4D_image.get_concrete_function(
+            tf.TensorSpec(shape=None, dtype=tf.float32))
+        for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
+            exp = tf.ones(shape=(1, 2, 4, 1))
+            res = fn(tf.ones(shape=shape))
+            self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
+
+    def test_to_4D_image_with_invalid_shape(self):
+        errors = (ValueError, tf.errors.InvalidArgumentError)
+        with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'):
+            img_utils.to_4D_image(tf.ones(shape=(1,)))
+
+        with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'):
+            img_utils.to_4D_image(tf.ones(shape=(1, 2, 4, 3, 2)))
+
+    def test_from_4D_image(self):
+        for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
+            exp = tf.ones(shape=shape)
+            res = img_utils.from_4D_image(
+                tf.ones(shape=(1, 2, 4, 1)), len(shape))
+            # static shape:
+            self.assertAllEqual(exp.get_shape(), res.get_shape())
+            self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
+
+    def test_from_4D_image_with_unknown_shape(self):
+        for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1):
+            exp = tf.ones(shape=shape)
+            fn = img_utils.from_4D_image.get_concrete_function(
+                tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape))
+            res = fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape))
+            self.assertAllEqual(self.evaluate(exp), self.evaluate(res))
+
+    def test_from_4D_image_with_invalid_data(self):
+        with self.assertRaises(ValueError):
+            self.evaluate(
+                img_utils.from_4D_image(tf.ones(shape=(2, 2, 4, 1)), 2))
+
+        with self.assertRaises(tf.errors.InvalidArgumentError):
+            self.evaluate(
+                img_utils.from_4D_image(
+                    tf.ones(shape=(2, 2, 4, 1)), tf.constant(2)))
+
+    def test_from_4D_image_with_invalid_shape(self):
+        errors = (ValueError, tf.errors.InvalidArgumentError)
+        for rank in 2, tf.constant(2):
+            with self.subTest(rank=rank):
+                with self.assertRaisesRegexp(errors,
+                                             '`image` must be 4D tensor'):
+                    img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank)
+
+                with self.assertRaisesRegexp(errors,
+                                             '`image` must be 4D tensor'):
+                    img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank)
+
+                with self.assertRaisesRegexp(errors,
+                                             '`image` must be 4D tensor'):
+                    img_utils.from_4D_image(
+                        tf.ones(shape=(1, 2, 4, 1, 1)), rank)
+
+
+if __name__ == "__main__":
+    tf.test.main()