From 21964444cbb07c0b4913f82407d5b831a0de9509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 19:16:58 +0800 Subject: [PATCH 1/9] ENH: add *_4D_image method --- tensorflow_addons/image/BUILD | 14 ++++ tensorflow_addons/image/utils.py | 96 +++++++++++++++++++++++++++ tensorflow_addons/image/utils_test.py | 92 +++++++++++++++++++++++++ 3 files changed, 202 insertions(+) create mode 100644 tensorflow_addons/image/utils.py create mode 100644 tensorflow_addons/image/utils_test.py diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index daeb47d3a8..d09ba988d4 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -11,6 +11,7 @@ py_library( "distort_image_ops.py", "filters.py", "transform_ops.py", + "utils.py", ]), data = [ "//tensorflow_addons/custom_ops/image:_distort_image_ops.so", @@ -85,3 +86,16 @@ py_test( ":image", ], ) + +py_test( + name = "utils_test", + size = "small", + srcs = [ + "utils_test.py", + ], + main = "utils_test.py", + srcs_version = "PY2AND3", + deps = [ + ":image", + ], +) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py new file mode 100644 index 0000000000..7affc12df8 --- /dev/null +++ b/tensorflow_addons/image/utils.py @@ -0,0 +1,96 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image util ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def get_ndims(image): + return image.get_shape().ndims or tf.rank(image) + + +@tf.function +def to_4D_image(image): + """Convert 2/3/4D image to 4D image. + + Args: + image: 2/3/4D tensor. + + Returns: + 4D tensor with the same type. + """ + tf.debugging.assert_rank_in(image, [2, 3, 4]) + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4D_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def _dynamic_to_4D_image(image): + shape = tf.shape(image) + original_rank = tf.rank(image) + # 4D image => [N, H, W, C] + # 3D image => [1, H, W, C] + # 2D image => [1, H, W, 1] + left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = tf.squeeze( + tf.pad( + tf.expand_dims(shape, axis=0), [[0, 0], [left_pad, right_pad]], + constant_values=1), + axis=0) + return tf.reshape(image, new_shape) + + +@tf.function +def from_4D_image(image, ndims): + """Convert back to an image with `ndims` rank. + + Args: + image: 4D tensor. + ndims: The original rank of the image. + + Returns: + `ndims`-D tensor with the same type. + """ + tf.debugging.assert_rank(image, 4) + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4D_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image + + +def _dynamic_from_4D_image(image, original_rank): + shape = tf.shape(image) + # 4D image <= [N, H, W, C] + # 3D image <= [1, H, W, C] + # 2D image <= [1, H, W, 1] + begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + original_shape = shape[begin:end] + return tf.reshape(image, original_shape) diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py new file mode 100644 index 0000000000..89da621565 --- /dev/null +++ b/tensorflow_addons/image/utils_test.py @@ -0,0 +1,92 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for util ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_addons.image import utils as img_utils +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class UtilsOpsTest(tf.test.TestCase): + def test_to_4D_image(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + self.assertAllEqual( + self.evaluate(tf.ones(shape=(1, 2, 4, 1))), + self.evaluate(img_utils.to_4D_image(tf.ones(shape=shape)))) + + def test_to_4D_image_with_unknown_shape(self): + fn = img_utils.to_4D_image.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32)) + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + image = tf.ones(shape=shape) + self.assertAllEqual( + self.evaluate(tf.ones(shape=(1, 2, 4, 1))), + self.evaluate(fn(image))) + + def test_to_4D_image_with_invalid_shape(self): + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.to_4D_image(tf.ones(shape=(1,))) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.to_4D_image(tf.ones(shape=(1, 2, 4, 3, 2))) + + def test_from_4D_image(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + self.assertAllEqual( + self.evaluate(tf.ones(shape=shape)), + self.evaluate( + img_utils.from_4D_image( + tf.ones(shape=(1, 2, 4, 1)), len(shape)))) + + def test_from_4D_image_with_invalid_data(self): + with self.assertRaises(ValueError): + self.evaluate( + img_utils.from_4D_image(tf.ones(shape=(2, 2, 4, 1)), 2)) + + with self.assertRaises(tf.errors.InvalidArgumentError): + self.evaluate( + img_utils.from_4D_image( + tf.ones(shape=(2, 2, 4, 1)), tf.constant(2))) + + def test_from_4D_image_with_unknown_shape(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + fn = img_utils.from_4D_image.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) + self.assertAllEqual( + self.evaluate(tf.ones(shape=shape)), + self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) + + def test_from_4D_image_with_invalid_shape(self): + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(1,)), 1) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4)), 2) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), 2) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(1, 2, 4, 1, 1)), 2) + + +if __name__ == "__main__": + tf.test.main() From b05b3b8fcbd0fe722b765f20884339baf97ad090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 19:23:05 +0800 Subject: [PATCH 2/9] ENH: transform_ops use *_4D_image --- tensorflow_addons/image/transform_ops.py | 47 ++++--------------- tensorflow_addons/image/transform_ops_test.py | 16 +++++++ 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/tensorflow_addons/image/transform_ops.py b/tensorflow_addons/image/transform_ops.py index 55df62ab17..a47dd0bc10 100644 --- a/tensorflow_addons/image/transform_ops.py +++ b/tensorflow_addons/image/transform_ops.py @@ -18,6 +18,7 @@ from __future__ import print_function import tensorflow as tf +from tensorflow_addons.image import utils as img_utils from tensorflow_addons.utils.resource_loader import get_path_to_datafile _image_ops_so = tf.load_op_library( @@ -40,8 +41,7 @@ def transform(images, Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). The rank must be statically known (the - shape is not `TensorShape(None)`. + (num_rows, num_columns) (HW). transforms: Projective transform matrix/matrices. A vector of length 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point @@ -72,16 +72,8 @@ def transform(images, transforms, name="transforms", dtype=tf.dtypes.float32) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - elif image_or_images.get_shape().ndims is None: - raise TypeError("image_or_images rank must be statically known") - elif len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") + images = img_utils.to_4D_image(image_or_images) + original_ndims = img_utils.get_ndims(image_or_images) if output_shape is None: output_shape = tf.shape(images)[1:3] @@ -109,12 +101,7 @@ def transform(images, output_shape=output_shape, transforms=transforms, interpolation=interpolation.upper()) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + return img_utils.from_4D_image(output, original_ndims) @tf.function @@ -299,8 +286,7 @@ def rotate(images, angles, interpolation="NEAREST", name=None): images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). The rank must be statically known (the - shape is not `TensorShape(None)`. + (num_rows, num_columns) (HW). angles: A scalar angle to rotate all images by, or (if images has rank 4) a vector of length num_images, with an angle for each image in the batch. @@ -319,16 +305,8 @@ def rotate(images, angles, interpolation="NEAREST", name=None): image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if image_or_images.get_shape().ndims is None: - raise TypeError("image_or_images rank must be statically known") - elif len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") + images = img_utils.to_4D_image(image_or_images) + original_ndims = img_utils.get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] @@ -336,11 +314,4 @@ def rotate(images, angles, interpolation="NEAREST", name=None): images, angles_to_projective_transforms(angles, image_height, image_width), interpolation=interpolation) - if image_or_images.get_shape().ndims is None: - raise TypeError("image_or_images rank must be statically known") - elif len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + return img_utils.from_4D_image(output, original_ndims) diff --git a/tensorflow_addons/image/transform_ops_test.py b/tensorflow_addons/image/transform_ops_test.py index 5bd8b0d341..deb0c24af3 100644 --- a/tensorflow_addons/image/transform_ops_test.py +++ b/tensorflow_addons/image/transform_ops_test.py @@ -72,6 +72,15 @@ def test_transform_static_output_shape(self): output_shape=tf.constant([3, 5])) self.assertAllEqual([3, 5], result.shape) + @test_utils.run_in_graph_and_eager_modes + def test_transform_unknown_shape(self): + fn = transform_ops.transform.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), + [1, 0, 0, 0, 1, 0, 0, 0]) + for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3): + image = tf.ones(shape=shape) + self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image))) + def _test_grad(self, input_shape, output_shape=None): image_size = tf.math.cumprod(input_shape)[-1] image_size = tf.cast(image_size, tf.float32) @@ -270,6 +279,13 @@ def test_rotate_static_shape(self): image, tf.random.uniform((), -1, 1), interpolation="BILINEAR") self.assertEqual(image.get_shape(), result.get_shape()) + def test_unknown_shape(self): + fn = transform_ops.rotate.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), 0) + for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3): + image = tf.ones(shape=shape) + self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image))) + if __name__ == "__main__": tf.test.main() From 27de34d0f79175f699c0661310cbe291e4daea5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 21:44:58 +0800 Subject: [PATCH 3/9] TST: more test case --- tensorflow_addons/image/utils_test.py | 41 ++++++++++++++------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py index 89da621565..cf611fa4c0 100644 --- a/tensorflow_addons/image/utils_test.py +++ b/tensorflow_addons/image/utils_test.py @@ -56,6 +56,14 @@ def test_from_4D_image(self): img_utils.from_4D_image( tf.ones(shape=(1, 2, 4, 1)), len(shape)))) + def test_from_4D_image_with_unknown_shape(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + fn = img_utils.from_4D_image.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) + self.assertAllEqual( + self.evaluate(tf.ones(shape=shape)), + self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) + def test_from_4D_image_with_invalid_data(self): with self.assertRaises(ValueError): self.evaluate( @@ -66,26 +74,21 @@ def test_from_4D_image_with_invalid_data(self): img_utils.from_4D_image( tf.ones(shape=(2, 2, 4, 1)), tf.constant(2))) - def test_from_4D_image_with_unknown_shape(self): - for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - fn = img_utils.from_4D_image.get_concrete_function( - tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) - self.assertAllEqual( - self.evaluate(tf.ones(shape=shape)), - self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) - def test_from_4D_image_with_invalid_shape(self): - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(1,)), 1) - - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(2, 4)), 2) - - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), 2) - - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(1, 2, 4, 1, 1)), 2) + for rank in 2, tf.constant(2): + with self.subTest(rank=rank): + with self.assertRaises((ValueError, + tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank) + + with self.assertRaises((ValueError, + tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank) + + with self.assertRaises((ValueError, + tf.errors.InvalidArgumentError)): + img_utils.from_4D_image( + tf.ones(shape=(1, 2, 4, 1, 1)), rank) if __name__ == "__main__": From 60c3d34dc4545d82bfa65fc3fc6ab40ab881a35e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 22:01:23 +0800 Subject: [PATCH 4/9] CLN: simpler way to calcualte new_shape --- tensorflow_addons/image/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py index 7affc12df8..d614ba76b4 100644 --- a/tensorflow_addons/image/utils.py +++ b/tensorflow_addons/image/utils.py @@ -55,11 +55,13 @@ def _dynamic_to_4D_image(image): # 2D image => [1, H, W, 1] left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) - new_shape = tf.squeeze( - tf.pad( - tf.expand_dims(shape, axis=0), [[0, 0], [left_pad, right_pad]], - constant_values=1), + # yapf: disable + new_shape = tf.concat( + [tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32)], axis=0) + # yapf: enable return tf.reshape(image, new_shape) From 900177485d86402d159edb0568f6b6a8c3a26728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 22:24:06 +0800 Subject: [PATCH 5/9] TST: static shape check --- tensorflow_addons/image/utils_test.py | 32 ++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py index cf611fa4c0..1420c37c0f 100644 --- a/tensorflow_addons/image/utils_test.py +++ b/tensorflow_addons/image/utils_test.py @@ -28,18 +28,19 @@ class UtilsOpsTest(tf.test.TestCase): def test_to_4D_image(self): for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - self.assertAllEqual( - self.evaluate(tf.ones(shape=(1, 2, 4, 1))), - self.evaluate(img_utils.to_4D_image(tf.ones(shape=shape)))) + exp = tf.ones(shape=(1, 2, 4, 1)) + res = img_utils.to_4D_image(tf.ones(shape=shape)) + # static shape: + self.assertAllEqual(exp.get_shape(), res.get_shape()) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_to_4D_image_with_unknown_shape(self): fn = img_utils.to_4D_image.get_concrete_function( tf.TensorSpec(shape=None, dtype=tf.float32)) for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - image = tf.ones(shape=shape) - self.assertAllEqual( - self.evaluate(tf.ones(shape=(1, 2, 4, 1))), - self.evaluate(fn(image))) + exp = tf.ones(shape=(1, 2, 4, 1)) + res = fn(tf.ones(shape=shape)) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_to_4D_image_with_invalid_shape(self): with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): @@ -50,19 +51,20 @@ def test_to_4D_image_with_invalid_shape(self): def test_from_4D_image(self): for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - self.assertAllEqual( - self.evaluate(tf.ones(shape=shape)), - self.evaluate( - img_utils.from_4D_image( - tf.ones(shape=(1, 2, 4, 1)), len(shape)))) + exp = tf.ones(shape=shape) + res = img_utils.from_4D_image( + tf.ones(shape=(1, 2, 4, 1)), len(shape)) + # static shape: + self.assertAllEqual(exp.get_shape(), res.get_shape()) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_from_4D_image_with_unknown_shape(self): for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + exp = tf.ones(shape=shape) fn = img_utils.from_4D_image.get_concrete_function( tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) - self.assertAllEqual( - self.evaluate(tf.ones(shape=shape)), - self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) + res = fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_from_4D_image_with_invalid_data(self): with self.assertRaises(ValueError): From eaf339e37c932641507ef823894ae5d36df72791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 5 Jul 2019 13:13:48 +0800 Subject: [PATCH 6/9] ENH: use tf.control_dependencies --- tensorflow_addons/image/utils.py | 52 +++++++++++++++------------ tensorflow_addons/image/utils_test.py | 18 +++++----- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py index d614ba76b4..2229b9b607 100644 --- a/tensorflow_addons/image/utils.py +++ b/tensorflow_addons/image/utils.py @@ -35,23 +35,26 @@ def to_4D_image(image): Returns: 4D tensor with the same type. """ - tf.debugging.assert_rank_in(image, [2, 3, 4]) - ndims = image.get_shape().ndims - if ndims is None: - return _dynamic_to_4D_image(image) - elif ndims == 2: - return image[None, :, :, None] - elif ndims == 3: - return image[None, :, :, :] - else: - return image + with tf.control_dependencies([ + tf.debugging.assert_rank_in( + image, [2, 3, 4], message='`image` must be 2/3/4D tensor') + ]): + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4D_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image def _dynamic_to_4D_image(image): shape = tf.shape(image) original_rank = tf.rank(image) - # 4D image => [N, H, W, C] - # 3D image => [1, H, W, C] + # 4D image => [N, H, W, C] or [N, C, H, W] + # 3D image => [1, H, W, C] or [1, C, H, W] # 2D image => [1, H, W, 1] left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) @@ -76,21 +79,24 @@ def from_4D_image(image, ndims): Returns: `ndims`-D tensor with the same type. """ - tf.debugging.assert_rank(image, 4) - if isinstance(ndims, tf.Tensor): - return _dynamic_from_4D_image(image, ndims) - elif ndims == 2: - return tf.squeeze(image, [0, 3]) - elif ndims == 3: - return tf.squeeze(image, [0]) - else: - return image + with tf.control_dependencies([ + tf.debugging.assert_rank( + image, 4, message='`image` must be 4D tensor') + ]): + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4D_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image def _dynamic_from_4D_image(image, original_rank): shape = tf.shape(image) - # 4D image <= [N, H, W, C] - # 3D image <= [1, H, W, C] + # 4D image <= [N, H, W, C] or [N, C, H, W] + # 3D image <= [1, H, W, C] or [1, C, H, W] # 2D image <= [1, H, W, 1] begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py index 1420c37c0f..1e7ddb7d13 100644 --- a/tensorflow_addons/image/utils_test.py +++ b/tensorflow_addons/image/utils_test.py @@ -43,10 +43,11 @@ def test_to_4D_image_with_unknown_shape(self): self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_to_4D_image_with_invalid_shape(self): - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + errors = (ValueError, tf.errors.InvalidArgumentError) + with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'): img_utils.to_4D_image(tf.ones(shape=(1,))) - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + with self.assertRaisesRegexp(errors, '`image` must be 2/3/4D tensor'): img_utils.to_4D_image(tf.ones(shape=(1, 2, 4, 3, 2))) def test_from_4D_image(self): @@ -77,18 +78,19 @@ def test_from_4D_image_with_invalid_data(self): tf.ones(shape=(2, 2, 4, 1)), tf.constant(2))) def test_from_4D_image_with_invalid_shape(self): + errors = (ValueError, tf.errors.InvalidArgumentError) for rank in 2, tf.constant(2): with self.subTest(rank=rank): - with self.assertRaises((ValueError, - tf.errors.InvalidArgumentError)): + with self.assertRaisesRegexp(errors, + '`image` must be 4D tensor'): img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank) - with self.assertRaises((ValueError, - tf.errors.InvalidArgumentError)): + with self.assertRaisesRegexp(errors, + '`image` must be 4D tensor'): img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank) - with self.assertRaises((ValueError, - tf.errors.InvalidArgumentError)): + with self.assertRaisesRegexp(errors, + '`image` must be 4D tensor'): img_utils.from_4D_image( tf.ones(shape=(1, 2, 4, 1, 1)), rank) From e3b10c3f36a1ffa298e9fcb37ef61e54309ea02e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 5 Jul 2019 13:30:31 +0800 Subject: [PATCH 7/9] CLN: fix code style --- tensorflow_addons/image/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py index 2229b9b607..4b64ed91f2 100644 --- a/tensorflow_addons/image/utils.py +++ b/tensorflow_addons/image/utils.py @@ -35,10 +35,12 @@ def to_4D_image(image): Returns: 4D tensor with the same type. """ + # yapf:disable with tf.control_dependencies([ - tf.debugging.assert_rank_in( - image, [2, 3, 4], message='`image` must be 2/3/4D tensor') + tf.debugging.assert_rank_in( + image, [2, 3, 4], message='`image` must be 2/3/4D tensor') ]): + # ydpf: enable ndims = image.get_shape().ndims if ndims is None: return _dynamic_to_4D_image(image) @@ -79,10 +81,12 @@ def from_4D_image(image, ndims): Returns: `ndims`-D tensor with the same type. """ + # yapf:disable with tf.control_dependencies([ - tf.debugging.assert_rank( - image, 4, message='`image` must be 4D tensor') + tf.debugging.assert_rank( + image, 4, message='`image` must be 4D tensor') ]): + # yapf:enable if isinstance(ndims, tf.Tensor): return _dynamic_from_4D_image(image, ndims) elif ndims == 2: From bc097ffb5400b9cc860263a77f8288e6cad5de49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Fri, 5 Jul 2019 13:32:30 +0800 Subject: [PATCH 8/9] CLN: rename original_shape to new_shape --- tensorflow_addons/image/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py index 4b64ed91f2..bbfea74e89 100644 --- a/tensorflow_addons/image/utils.py +++ b/tensorflow_addons/image/utils.py @@ -104,5 +104,5 @@ def _dynamic_from_4D_image(image, original_rank): # 2D image <= [1, H, W, 1] begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) - original_shape = shape[begin:end] - return tf.reshape(image, original_shape) + new_shape = shape[begin:end] + return tf.reshape(image, new_shape) From fd9e020ac1923951dd2c0986b970decdda29d6cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Sat, 6 Jul 2019 07:02:10 +0800 Subject: [PATCH 9/9] CLN: fix typo --- tensorflow_addons/image/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py index bbfea74e89..036a4caaa1 100644 --- a/tensorflow_addons/image/utils.py +++ b/tensorflow_addons/image/utils.py @@ -40,7 +40,7 @@ def to_4D_image(image): tf.debugging.assert_rank_in( image, [2, 3, 4], message='`image` must be 2/3/4D tensor') ]): - # ydpf: enable + # yapf: enable ndims = image.get_shape().ndims if ndims is None: return _dynamic_to_4D_image(image)