From 45535c46901d7e5c2d06c191cb700770c2a11f09 Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Thu, 25 Apr 2019 20:02:27 -0700 Subject: [PATCH 01/14] Add bzl file to control CXX11_ABI --- tensorflow_addons/D_GLIBCXX_USE_CXX11_ABI.bzl | 1 + tensorflow_addons/custom_ops/image/BUILD | 6 ++++-- tensorflow_addons/custom_ops/seq2seq/BUILD | 4 +++- tensorflow_addons/custom_ops/text/BUILD | 4 +++- 4 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 tensorflow_addons/D_GLIBCXX_USE_CXX11_ABI.bzl diff --git a/tensorflow_addons/D_GLIBCXX_USE_CXX11_ABI.bzl b/tensorflow_addons/D_GLIBCXX_USE_CXX11_ABI.bzl new file mode 100644 index 0000000000..871a646cff --- /dev/null +++ b/tensorflow_addons/D_GLIBCXX_USE_CXX11_ABI.bzl @@ -0,0 +1 @@ +D_GLIBCXX_USE_CXX11_ABI = "-D_GLIBCXX_USE_CXX11_ABI=0" \ No newline at end of file diff --git a/tensorflow_addons/custom_ops/image/BUILD b/tensorflow_addons/custom_ops/image/BUILD index d9ba3029f5..acf156791f 100644 --- a/tensorflow_addons/custom_ops/image/BUILD +++ b/tensorflow_addons/custom_ops/image/BUILD @@ -2,6 +2,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +load("//tensorflow_addons:D_GLIBCXX_USE_CXX11_ABI.bzl", "D_GLIBCXX_USE_CXX11_ABI") + cc_binary( name = "_distort_image_ops.so", srcs = [ @@ -12,7 +14,7 @@ cc_binary( copts = [ "-pthread", "-std=c++11", - "-D_GLIBCXX_USE_CXX11_ABI=0", + D_GLIBCXX_USE_CXX11_ABI, ], linkshared = 1, deps = [ @@ -33,7 +35,7 @@ cc_binary( copts = [ "-pthread", "-std=c++11", - "-D_GLIBCXX_USE_CXX11_ABI=0", + D_GLIBCXX_USE_CXX11_ABI, ], linkshared = 1, deps = [ diff --git a/tensorflow_addons/custom_ops/seq2seq/BUILD b/tensorflow_addons/custom_ops/seq2seq/BUILD index cf0849842e..5efa1db742 100644 --- a/tensorflow_addons/custom_ops/seq2seq/BUILD +++ b/tensorflow_addons/custom_ops/seq2seq/BUILD @@ -2,6 +2,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +load("//tensorflow_addons:D_GLIBCXX_USE_CXX11_ABI.bzl", "D_GLIBCXX_USE_CXX11_ABI") + cc_binary( name = "_beam_search_ops.so", srcs = [ @@ -13,7 +15,7 @@ cc_binary( copts = [ "-pthread", "-std=c++11", - "-D_GLIBCXX_USE_CXX11_ABI=0", + D_GLIBCXX_USE_CXX11_ABI, ], linkshared = 1, deps = [ diff --git a/tensorflow_addons/custom_ops/text/BUILD b/tensorflow_addons/custom_ops/text/BUILD index c8375f7b36..4bf34e2476 100644 --- a/tensorflow_addons/custom_ops/text/BUILD +++ b/tensorflow_addons/custom_ops/text/BUILD @@ -2,6 +2,8 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) +load("//tensorflow_addons:D_GLIBCXX_USE_CXX11_ABI.bzl", "D_GLIBCXX_USE_CXX11_ABI") + cc_binary( name = "_skip_gram_ops.so", srcs = [ @@ -11,7 +13,7 @@ cc_binary( copts = [ "-pthread", "-std=c++11", - "-D_GLIBCXX_USE_CXX11_ABI=0", + D_GLIBCXX_USE_CXX11_ABI, ], linkshared = 1, deps = [ From 89855b623c8d3df5dc79e5ed58f3efc968c34e56 Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Sat, 4 May 2019 08:58:27 -0700 Subject: [PATCH 02/14] Rename macros --- .../{D_GLIBCXX_USE_CXX11_ABI.bzl => addons_bazel_macros.bzl} | 0 tensorflow_addons/custom_ops/image/BUILD | 2 +- tensorflow_addons/custom_ops/seq2seq/BUILD | 2 +- tensorflow_addons/custom_ops/text/BUILD | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename tensorflow_addons/{D_GLIBCXX_USE_CXX11_ABI.bzl => addons_bazel_macros.bzl} (100%) diff --git a/tensorflow_addons/D_GLIBCXX_USE_CXX11_ABI.bzl b/tensorflow_addons/addons_bazel_macros.bzl similarity index 100% rename from tensorflow_addons/D_GLIBCXX_USE_CXX11_ABI.bzl rename to tensorflow_addons/addons_bazel_macros.bzl diff --git a/tensorflow_addons/custom_ops/image/BUILD b/tensorflow_addons/custom_ops/image/BUILD index acf156791f..59d9267ce4 100644 --- a/tensorflow_addons/custom_ops/image/BUILD +++ b/tensorflow_addons/custom_ops/image/BUILD @@ -2,7 +2,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) -load("//tensorflow_addons:D_GLIBCXX_USE_CXX11_ABI.bzl", "D_GLIBCXX_USE_CXX11_ABI") +load("//tensorflow_addons:addons_bazel_macros.bzl", "D_GLIBCXX_USE_CXX11_ABI") cc_binary( name = "_distort_image_ops.so", diff --git a/tensorflow_addons/custom_ops/seq2seq/BUILD b/tensorflow_addons/custom_ops/seq2seq/BUILD index 5efa1db742..733db76a93 100644 --- a/tensorflow_addons/custom_ops/seq2seq/BUILD +++ b/tensorflow_addons/custom_ops/seq2seq/BUILD @@ -2,7 +2,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) -load("//tensorflow_addons:D_GLIBCXX_USE_CXX11_ABI.bzl", "D_GLIBCXX_USE_CXX11_ABI") +load("//tensorflow_addons:addons_bazel_macros.bzl", "D_GLIBCXX_USE_CXX11_ABI") cc_binary( name = "_beam_search_ops.so", diff --git a/tensorflow_addons/custom_ops/text/BUILD b/tensorflow_addons/custom_ops/text/BUILD index 4bf34e2476..8bc1db6d63 100644 --- a/tensorflow_addons/custom_ops/text/BUILD +++ b/tensorflow_addons/custom_ops/text/BUILD @@ -2,7 +2,7 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//visibility:public"]) -load("//tensorflow_addons:D_GLIBCXX_USE_CXX11_ABI.bzl", "D_GLIBCXX_USE_CXX11_ABI") +load("//tensorflow_addons:addons_bazel_macros.bzl", "D_GLIBCXX_USE_CXX11_ABI") cc_binary( name = "_skip_gram_ops.so", From 6568587ba0831b95e7df43d4b01f54694510bede Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Sat, 4 May 2019 09:23:26 -0700 Subject: [PATCH 03/14] Add tf.sysconfig lookup --- configure.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/configure.sh b/configure.sh index cc0d9d0fee..248aaa9959 100755 --- a/configure.sh +++ b/configure.sh @@ -45,7 +45,9 @@ pip install $QUIET_FLAG -r requirements.txt TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) TF_SHAREDLIB=( $(python -c 'import tensorflow as tf; print(tf.sysconfig.get_link_flags()[-1])') ) +TF_CXX11_ABI_FLAG=( $(python -c 'import tensorflow as tf; print(tf.sysconfig.CXX11_ABI_FLAG)') ) write_action_env_to_bazelrc "TF_HEADER_DIR" ${TF_CFLAGS:2} write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${TF_LFLAGS:2} write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${TF_SHAREDLIB:3} +write_action_env_to_bazelrc "TF_CXX11_ABI_FLAG" ${TF_CXX11_ABI_FLAG:1} From 8294024b1b52b201292266c864e95607344f10cd Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Sun, 5 May 2019 11:59:16 -0700 Subject: [PATCH 04/14] Add TODO note in configure --- configure.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/configure.sh b/configure.sh index 248aaa9959..d1aeec17d7 100755 --- a/configure.sh +++ b/configure.sh @@ -51,3 +51,4 @@ write_action_env_to_bazelrc "TF_HEADER_DIR" ${TF_CFLAGS:2} write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${TF_LFLAGS:2} write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${TF_SHAREDLIB:3} write_action_env_to_bazelrc "TF_CXX11_ABI_FLAG" ${TF_CXX11_ABI_FLAG:1} +# TODO: propagate TF_* variables to bazel macro file From 04eed3feb11278dff091bd76cf54d9109bba341d Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Sun, 5 May 2019 16:45:01 -0700 Subject: [PATCH 05/14] Fix --- configure.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configure.sh b/configure.sh index d1aeec17d7..ea7068dd51 100755 --- a/configure.sh +++ b/configure.sh @@ -50,5 +50,5 @@ TF_CXX11_ABI_FLAG=( $(python -c 'import tensorflow as tf; print(tf.sysconfig.CXX write_action_env_to_bazelrc "TF_HEADER_DIR" ${TF_CFLAGS:2} write_action_env_to_bazelrc "TF_SHARED_LIBRARY_DIR" ${TF_LFLAGS:2} write_action_env_to_bazelrc "TF_SHARED_LIBRARY_NAME" ${TF_SHAREDLIB:3} -write_action_env_to_bazelrc "TF_CXX11_ABI_FLAG" ${TF_CXX11_ABI_FLAG:1} +write_action_env_to_bazelrc "TF_CXX11_ABI_FLAG" ${TF_CXX11_ABI_FLAG} # TODO: propagate TF_* variables to bazel macro file From b698f36c04a0f0bbcb98259108af45161868ce52 Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Mon, 13 May 2019 22:54:48 -0700 Subject: [PATCH 06/14] Add failing test for dataset image transforms --- tensorflow_addons/image/transform_ops_test.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tensorflow_addons/image/transform_ops_test.py b/tensorflow_addons/image/transform_ops_test.py index 840333d879..c4cadec7bd 100644 --- a/tensorflow_addons/image/transform_ops_test.py +++ b/tensorflow_addons/image/transform_ops_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import os import numpy as np import tensorflow as tf @@ -287,5 +288,44 @@ def test_rotate_static_shape(self): self.assertEqual(image.get_shape(), result.get_shape()) +def preprocess_image(image, addons): + image = tf.image.decode_image(image, channels=3) + + if addons: + angle = 1.57 + image = transform_ops.rotate(image, angle) + print(type(image)) + print(image.get_shape()) + else: + image = tf.image.rot90(image) + print(type(image)) + print(image.get_shape()) + + return image + + +def load_and_preprocess_tfa(path): + image = tf.io.read_file(path) + return preprocess_image(image, addons=True) + + +def load_and_preprocess_tf(path): + image = tf.io.read_file(path) + return preprocess_image(image, addons=False) + + +@test_utils.run_all_in_graph_and_eager_modes +class DatasetTests(tf.test.TestCase): + prefix_path = "tensorflow/core/lib" + image_path = os.path.join(prefix_path, "gif", "testdata", "scan.gif") + + def test_rotate_static_shape(self): + path_ds = tf.data.Dataset.from_tensor_slices([self.image_path]) + image_ds = path_ds.map(load_and_preprocess_tfa) + for image in image_ds: + print(type(image)) + print(image.shape) + + if __name__ == "__main__": tf.test.main() From c01ac2861b556b5be678d4fdb8005455ce3f7eb4 Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Mon, 13 May 2019 23:00:30 -0700 Subject: [PATCH 07/14] Remove tf-only function --- tensorflow_addons/image/transform_ops_test.py | 20 +++++-------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/tensorflow_addons/image/transform_ops_test.py b/tensorflow_addons/image/transform_ops_test.py index c4cadec7bd..6a6be1ca02 100644 --- a/tensorflow_addons/image/transform_ops_test.py +++ b/tensorflow_addons/image/transform_ops_test.py @@ -288,18 +288,13 @@ def test_rotate_static_shape(self): self.assertEqual(image.get_shape(), result.get_shape()) -def preprocess_image(image, addons): +def preprocess_image(image): image = tf.image.decode_image(image, channels=3) - if addons: - angle = 1.57 - image = transform_ops.rotate(image, angle) - print(type(image)) - print(image.get_shape()) - else: - image = tf.image.rot90(image) - print(type(image)) - print(image.get_shape()) + angle = 1.57 + image = transform_ops.rotate(image, angle) + print(type(image)) + print(image.get_shape()) return image @@ -309,11 +304,6 @@ def load_and_preprocess_tfa(path): return preprocess_image(image, addons=True) -def load_and_preprocess_tf(path): - image = tf.io.read_file(path) - return preprocess_image(image, addons=False) - - @test_utils.run_all_in_graph_and_eager_modes class DatasetTests(tf.test.TestCase): prefix_path = "tensorflow/core/lib" From 335c0300c9dac7fec6e52174736c591687f466a7 Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Mon, 13 May 2019 23:01:39 -0700 Subject: [PATCH 08/14] Remove tf-only function --- tensorflow_addons/image/transform_ops_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/image/transform_ops_test.py b/tensorflow_addons/image/transform_ops_test.py index 6a6be1ca02..021f367a41 100644 --- a/tensorflow_addons/image/transform_ops_test.py +++ b/tensorflow_addons/image/transform_ops_test.py @@ -301,7 +301,7 @@ def preprocess_image(image): def load_and_preprocess_tfa(path): image = tf.io.read_file(path) - return preprocess_image(image, addons=True) + return preprocess_image(image) @test_utils.run_all_in_graph_and_eager_modes From 21964444cbb07c0b4913f82407d5b831a0de9509 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 19:16:58 +0800 Subject: [PATCH 09/14] ENH: add *_4D_image method --- tensorflow_addons/image/BUILD | 14 ++++ tensorflow_addons/image/utils.py | 96 +++++++++++++++++++++++++++ tensorflow_addons/image/utils_test.py | 92 +++++++++++++++++++++++++ 3 files changed, 202 insertions(+) create mode 100644 tensorflow_addons/image/utils.py create mode 100644 tensorflow_addons/image/utils_test.py diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index daeb47d3a8..d09ba988d4 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -11,6 +11,7 @@ py_library( "distort_image_ops.py", "filters.py", "transform_ops.py", + "utils.py", ]), data = [ "//tensorflow_addons/custom_ops/image:_distort_image_ops.so", @@ -85,3 +86,16 @@ py_test( ":image", ], ) + +py_test( + name = "utils_test", + size = "small", + srcs = [ + "utils_test.py", + ], + main = "utils_test.py", + srcs_version = "PY2AND3", + deps = [ + ":image", + ], +) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py new file mode 100644 index 0000000000..7affc12df8 --- /dev/null +++ b/tensorflow_addons/image/utils.py @@ -0,0 +1,96 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Image util ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +def get_ndims(image): + return image.get_shape().ndims or tf.rank(image) + + +@tf.function +def to_4D_image(image): + """Convert 2/3/4D image to 4D image. + + Args: + image: 2/3/4D tensor. + + Returns: + 4D tensor with the same type. + """ + tf.debugging.assert_rank_in(image, [2, 3, 4]) + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4D_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def _dynamic_to_4D_image(image): + shape = tf.shape(image) + original_rank = tf.rank(image) + # 4D image => [N, H, W, C] + # 3D image => [1, H, W, C] + # 2D image => [1, H, W, 1] + left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = tf.squeeze( + tf.pad( + tf.expand_dims(shape, axis=0), [[0, 0], [left_pad, right_pad]], + constant_values=1), + axis=0) + return tf.reshape(image, new_shape) + + +@tf.function +def from_4D_image(image, ndims): + """Convert back to an image with `ndims` rank. + + Args: + image: 4D tensor. + ndims: The original rank of the image. + + Returns: + `ndims`-D tensor with the same type. + """ + tf.debugging.assert_rank(image, 4) + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4D_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image + + +def _dynamic_from_4D_image(image, original_rank): + shape = tf.shape(image) + # 4D image <= [N, H, W, C] + # 3D image <= [1, H, W, C] + # 2D image <= [1, H, W, 1] + begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + original_shape = shape[begin:end] + return tf.reshape(image, original_shape) diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py new file mode 100644 index 0000000000..89da621565 --- /dev/null +++ b/tensorflow_addons/image/utils_test.py @@ -0,0 +1,92 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for util ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_addons.image import utils as img_utils +from tensorflow_addons.utils import test_utils + + +@test_utils.run_all_in_graph_and_eager_modes +class UtilsOpsTest(tf.test.TestCase): + def test_to_4D_image(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + self.assertAllEqual( + self.evaluate(tf.ones(shape=(1, 2, 4, 1))), + self.evaluate(img_utils.to_4D_image(tf.ones(shape=shape)))) + + def test_to_4D_image_with_unknown_shape(self): + fn = img_utils.to_4D_image.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32)) + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + image = tf.ones(shape=shape) + self.assertAllEqual( + self.evaluate(tf.ones(shape=(1, 2, 4, 1))), + self.evaluate(fn(image))) + + def test_to_4D_image_with_invalid_shape(self): + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.to_4D_image(tf.ones(shape=(1,))) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.to_4D_image(tf.ones(shape=(1, 2, 4, 3, 2))) + + def test_from_4D_image(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + self.assertAllEqual( + self.evaluate(tf.ones(shape=shape)), + self.evaluate( + img_utils.from_4D_image( + tf.ones(shape=(1, 2, 4, 1)), len(shape)))) + + def test_from_4D_image_with_invalid_data(self): + with self.assertRaises(ValueError): + self.evaluate( + img_utils.from_4D_image(tf.ones(shape=(2, 2, 4, 1)), 2)) + + with self.assertRaises(tf.errors.InvalidArgumentError): + self.evaluate( + img_utils.from_4D_image( + tf.ones(shape=(2, 2, 4, 1)), tf.constant(2))) + + def test_from_4D_image_with_unknown_shape(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + fn = img_utils.from_4D_image.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) + self.assertAllEqual( + self.evaluate(tf.ones(shape=shape)), + self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) + + def test_from_4D_image_with_invalid_shape(self): + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(1,)), 1) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4)), 2) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), 2) + + with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(1, 2, 4, 1, 1)), 2) + + +if __name__ == "__main__": + tf.test.main() From b05b3b8fcbd0fe722b765f20884339baf97ad090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 19:23:05 +0800 Subject: [PATCH 10/14] ENH: transform_ops use *_4D_image --- tensorflow_addons/image/transform_ops.py | 47 ++++--------------- tensorflow_addons/image/transform_ops_test.py | 16 +++++++ 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/tensorflow_addons/image/transform_ops.py b/tensorflow_addons/image/transform_ops.py index 55df62ab17..a47dd0bc10 100644 --- a/tensorflow_addons/image/transform_ops.py +++ b/tensorflow_addons/image/transform_ops.py @@ -18,6 +18,7 @@ from __future__ import print_function import tensorflow as tf +from tensorflow_addons.image import utils as img_utils from tensorflow_addons.utils.resource_loader import get_path_to_datafile _image_ops_so = tf.load_op_library( @@ -40,8 +41,7 @@ def transform(images, Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). The rank must be statically known (the - shape is not `TensorShape(None)`. + (num_rows, num_columns) (HW). transforms: Projective transform matrix/matrices. A vector of length 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point @@ -72,16 +72,8 @@ def transform(images, transforms, name="transforms", dtype=tf.dtypes.float32) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - elif image_or_images.get_shape().ndims is None: - raise TypeError("image_or_images rank must be statically known") - elif len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") + images = img_utils.to_4D_image(image_or_images) + original_ndims = img_utils.get_ndims(image_or_images) if output_shape is None: output_shape = tf.shape(images)[1:3] @@ -109,12 +101,7 @@ def transform(images, output_shape=output_shape, transforms=transforms, interpolation=interpolation.upper()) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + return img_utils.from_4D_image(output, original_ndims) @tf.function @@ -299,8 +286,7 @@ def rotate(images, angles, interpolation="NEAREST", name=None): images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). The rank must be statically known (the - shape is not `TensorShape(None)`. + (num_rows, num_columns) (HW). angles: A scalar angle to rotate all images by, or (if images has rank 4) a vector of length num_images, with an angle for each image in the batch. @@ -319,16 +305,8 @@ def rotate(images, angles, interpolation="NEAREST", name=None): image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if image_or_images.get_shape().ndims is None: - raise TypeError("image_or_images rank must be statically known") - elif len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") + images = img_utils.to_4D_image(image_or_images) + original_ndims = img_utils.get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] @@ -336,11 +314,4 @@ def rotate(images, angles, interpolation="NEAREST", name=None): images, angles_to_projective_transforms(angles, image_height, image_width), interpolation=interpolation) - if image_or_images.get_shape().ndims is None: - raise TypeError("image_or_images rank must be statically known") - elif len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + return img_utils.from_4D_image(output, original_ndims) diff --git a/tensorflow_addons/image/transform_ops_test.py b/tensorflow_addons/image/transform_ops_test.py index 5bd8b0d341..deb0c24af3 100644 --- a/tensorflow_addons/image/transform_ops_test.py +++ b/tensorflow_addons/image/transform_ops_test.py @@ -72,6 +72,15 @@ def test_transform_static_output_shape(self): output_shape=tf.constant([3, 5])) self.assertAllEqual([3, 5], result.shape) + @test_utils.run_in_graph_and_eager_modes + def test_transform_unknown_shape(self): + fn = transform_ops.transform.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), + [1, 0, 0, 0, 1, 0, 0, 0]) + for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3): + image = tf.ones(shape=shape) + self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image))) + def _test_grad(self, input_shape, output_shape=None): image_size = tf.math.cumprod(input_shape)[-1] image_size = tf.cast(image_size, tf.float32) @@ -270,6 +279,13 @@ def test_rotate_static_shape(self): image, tf.random.uniform((), -1, 1), interpolation="BILINEAR") self.assertEqual(image.get_shape(), result.get_shape()) + def test_unknown_shape(self): + fn = transform_ops.rotate.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), 0) + for shape in (2, 4), (2, 4, 3), (1, 2, 4, 3): + image = tf.ones(shape=shape) + self.assertAllEqual(self.evaluate(image), self.evaluate(fn(image))) + if __name__ == "__main__": tf.test.main() From 27de34d0f79175f699c0661310cbe291e4daea5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 21:44:58 +0800 Subject: [PATCH 11/14] TST: more test case --- tensorflow_addons/image/utils_test.py | 41 ++++++++++++++------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py index 89da621565..cf611fa4c0 100644 --- a/tensorflow_addons/image/utils_test.py +++ b/tensorflow_addons/image/utils_test.py @@ -56,6 +56,14 @@ def test_from_4D_image(self): img_utils.from_4D_image( tf.ones(shape=(1, 2, 4, 1)), len(shape)))) + def test_from_4D_image_with_unknown_shape(self): + for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + fn = img_utils.from_4D_image.get_concrete_function( + tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) + self.assertAllEqual( + self.evaluate(tf.ones(shape=shape)), + self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) + def test_from_4D_image_with_invalid_data(self): with self.assertRaises(ValueError): self.evaluate( @@ -66,26 +74,21 @@ def test_from_4D_image_with_invalid_data(self): img_utils.from_4D_image( tf.ones(shape=(2, 2, 4, 1)), tf.constant(2))) - def test_from_4D_image_with_unknown_shape(self): - for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - fn = img_utils.from_4D_image.get_concrete_function( - tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) - self.assertAllEqual( - self.evaluate(tf.ones(shape=shape)), - self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) - def test_from_4D_image_with_invalid_shape(self): - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(1,)), 1) - - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(2, 4)), 2) - - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), 2) - - with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): - img_utils.from_4D_image(tf.ones(shape=(1, 2, 4, 1, 1)), 2) + for rank in 2, tf.constant(2): + with self.subTest(rank=rank): + with self.assertRaises((ValueError, + tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4)), rank) + + with self.assertRaises((ValueError, + tf.errors.InvalidArgumentError)): + img_utils.from_4D_image(tf.ones(shape=(2, 4, 1)), rank) + + with self.assertRaises((ValueError, + tf.errors.InvalidArgumentError)): + img_utils.from_4D_image( + tf.ones(shape=(1, 2, 4, 1, 1)), rank) if __name__ == "__main__": From 60c3d34dc4545d82bfa65fc3fc6ab40ab881a35e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 22:01:23 +0800 Subject: [PATCH 12/14] CLN: simpler way to calcualte new_shape --- tensorflow_addons/image/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/image/utils.py b/tensorflow_addons/image/utils.py index 7affc12df8..d614ba76b4 100644 --- a/tensorflow_addons/image/utils.py +++ b/tensorflow_addons/image/utils.py @@ -55,11 +55,13 @@ def _dynamic_to_4D_image(image): # 2D image => [1, H, W, 1] left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) - new_shape = tf.squeeze( - tf.pad( - tf.expand_dims(shape, axis=0), [[0, 0], [left_pad, right_pad]], - constant_values=1), + # yapf: disable + new_shape = tf.concat( + [tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32)], axis=0) + # yapf: enable return tf.reshape(image, new_shape) From 900177485d86402d159edb0568f6b6a8c3a26728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 4 Jul 2019 22:24:06 +0800 Subject: [PATCH 13/14] TST: static shape check --- tensorflow_addons/image/utils_test.py | 32 ++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tensorflow_addons/image/utils_test.py b/tensorflow_addons/image/utils_test.py index cf611fa4c0..1420c37c0f 100644 --- a/tensorflow_addons/image/utils_test.py +++ b/tensorflow_addons/image/utils_test.py @@ -28,18 +28,19 @@ class UtilsOpsTest(tf.test.TestCase): def test_to_4D_image(self): for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - self.assertAllEqual( - self.evaluate(tf.ones(shape=(1, 2, 4, 1))), - self.evaluate(img_utils.to_4D_image(tf.ones(shape=shape)))) + exp = tf.ones(shape=(1, 2, 4, 1)) + res = img_utils.to_4D_image(tf.ones(shape=shape)) + # static shape: + self.assertAllEqual(exp.get_shape(), res.get_shape()) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_to_4D_image_with_unknown_shape(self): fn = img_utils.to_4D_image.get_concrete_function( tf.TensorSpec(shape=None, dtype=tf.float32)) for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - image = tf.ones(shape=shape) - self.assertAllEqual( - self.evaluate(tf.ones(shape=(1, 2, 4, 1))), - self.evaluate(fn(image))) + exp = tf.ones(shape=(1, 2, 4, 1)) + res = fn(tf.ones(shape=shape)) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_to_4D_image_with_invalid_shape(self): with self.assertRaises((ValueError, tf.errors.InvalidArgumentError)): @@ -50,19 +51,20 @@ def test_to_4D_image_with_invalid_shape(self): def test_from_4D_image(self): for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): - self.assertAllEqual( - self.evaluate(tf.ones(shape=shape)), - self.evaluate( - img_utils.from_4D_image( - tf.ones(shape=(1, 2, 4, 1)), len(shape)))) + exp = tf.ones(shape=shape) + res = img_utils.from_4D_image( + tf.ones(shape=(1, 2, 4, 1)), len(shape)) + # static shape: + self.assertAllEqual(exp.get_shape(), res.get_shape()) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_from_4D_image_with_unknown_shape(self): for shape in (2, 4), (2, 4, 1), (1, 2, 4, 1): + exp = tf.ones(shape=shape) fn = img_utils.from_4D_image.get_concrete_function( tf.TensorSpec(shape=None, dtype=tf.float32), tf.size(shape)) - self.assertAllEqual( - self.evaluate(tf.ones(shape=shape)), - self.evaluate(fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)))) + res = fn(tf.ones(shape=(1, 2, 4, 1)), tf.size(shape)) + self.assertAllEqual(self.evaluate(exp), self.evaluate(res)) def test_from_4D_image_with_invalid_data(self): with self.assertRaises(ValueError): From 002004cf893715d8d961eca97ec84d8265f173fb Mon Sep 17 00:00:00 2001 From: Kyle Beauchamp Date: Thu, 4 Jul 2019 09:34:51 -0700 Subject: [PATCH 14/14] Fixes to image test case due to missing image files --- tensorflow_addons/image/transform_ops_test.py | 17 ++++++++++++++--- tensorflow_addons/utils/test_utils.py | 1 + 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/image/transform_ops_test.py b/tensorflow_addons/image/transform_ops_test.py index 6a852112c5..1a21e3dcaa 100644 --- a/tensorflow_addons/image/transform_ops_test.py +++ b/tensorflow_addons/image/transform_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +import tempfile import numpy as np import tensorflow as tf @@ -304,11 +305,21 @@ def load_and_preprocess_tfa(path): return preprocess_image(image) -@test_utils.run_all_in_graph_and_eager_modes class DatasetTests(tf.test.TestCase): - prefix_path = "tensorflow/core/lib" - image_path = os.path.join(prefix_path, "gif", "testdata", "scan.gif") + def setUp(self): + x = np.random.random(size=(24, 24, 3)) + + buf = tf.image.encode_jpeg(x, format='rgb', quality=100).numpy() + + with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as fh: + fh.write(buf) + + self.image_path = fh.name + + def tearDown(self): + os.unlink(self.image_path) + @test_utils.run_v2_only def test_rotate_static_shape(self): path_ds = tf.data.Dataset.from_tensor_slices([self.image_path]) image_ds = path_ds.map(load_and_preprocess_tfa) diff --git a/tensorflow_addons/utils/test_utils.py b/tensorflow_addons/utils/test_utils.py index d79e229909..4792515094 100644 --- a/tensorflow_addons/utils/test_utils.py +++ b/tensorflow_addons/utils/test_utils.py @@ -27,6 +27,7 @@ from tensorflow.python.framework.test_util import run_deprecated_v1 from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes from tensorflow.python.keras.testing_utils import layer_test +from tensorflow.python.framework.test_util import run_v2_only from tensorflow.python.keras import keras_parameterized # pylint: enable=unused-import