From d4a10feab5ab08d37e602a014dc02f3d2dc51416 Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Fri, 3 May 2019 11:09:16 +0800 Subject: [PATCH 1/9] speed up mean_filter2d with depthwise_conv2d --- tensorflow_addons/image/filters.py | 136 +++++++++--------- tensorflow_addons/image/filters_test.py | 178 ++++++++++++++++-------- tensorflow_addons/utils/keras_utils.py | 1 + 3 files changed, 189 insertions(+), 126 deletions(-) diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index 61b5334496..173c26f607 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -18,6 +18,7 @@ from __future__ import print_function import tensorflow as tf +from tensorflow_addons.utils import keras_utils @tf.function @@ -34,82 +35,87 @@ def func2(): return tf.cond(tf.math.greater(ma, one), func2, func1) -@tf.function -def mean_filter2d(image, filter_shape=(3, 3), name=None): - """This method performs Mean Filtering on image. Filter shape can be user - given. +def _pad(image, filter_shape, mode="CONSTANT", constant_values=0): + """Explicitly pad a 4-D image.""" + assert mode in ["CONSTANT", "REFLECT", "SYMMETRIC"] + filter_height, filter_width = filter_shape + pad_top = (filter_height - 1) // 2 + pad_bottom = filter_height - 1 - pad_top + pad_left = (filter_width - 1) // 2 + pad_right = filter_width - 1 - pad_left + paddings = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]] + return tf.pad(image, paddings, mode=mode, constant_values=constant_values) - This method takes both kind of images where pixel values lie between 0 to - 255 and where it lies between 0.0 and 1.0 - Args: - image: A 3D `Tensor` of type `float32` or 'int32' or 'float64' or - 'int64 and of shape`[rows, columns, channels]` - filter_shape: Optional Argument. A tuple of 2 integers (R,C). - R is the first value is the number of rows in the filter and - C is the second value in the filter is the number of columns - in the filter. This creates a filter of shape (R,C) or RxC - filter. Default value = (3,3) +@tf.function +def mean_filter2d(image, + filter_shape=(3, 3), + padding="REFLECT", + constant_values=0, + name=None): + """Perform mean filtering on image(s). - Returns: - A 3D mean filtered image tensor of shape [rows,columns,channels] and - type 'int32'. Pixel value of returned tensor ranges between 0 to 255 + Args: + image: Either a 3-D `Tensor` of shape `[height, width, channels]`, + or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`. + filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying + the height and width of the 2-D mean filter. Can be a single integer + to specify the same value for all spatial dimensions. + padding: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC". + The type of padding algorithm to use. + constant_values: A `scalar`, the pad value to use in "CONSTANT" + padding mode. + name: A name for this operation (optional). + Returns: + 3-D or 4-D float `Tensor`, as per the input. + Raises: + ValueError: If `image` is not 3 or 4-dimensional, + if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC", + or if `filter_shape` is invalid. """ - with tf.name_scope(name or "mean_filter2d"): - if not isinstance(filter_shape, tuple): - raise TypeError('Filter shape must be a tuple') - if len(filter_shape) != 2: - raise ValueError('Filter shape must be a tuple of 2 integers. ' - 'Got %s values in tuple' % len(filter_shape)) - filter_shapex = filter_shape[0] - filter_shapey = filter_shape[1] - if not isinstance(filter_shapex, int) or not isinstance( - filter_shapey, int): - raise TypeError('Size of the filter must be Integers') - (row, col, ch) = (image.shape[0], image.shape[1], image.shape[2]) - if row != None and col != None and ch != None: - (row, col, ch) = (int(row), int(col), int(ch)) - else: - raise TypeError( - 'All the Dimensions of the input image tensor must be \ - Integers.') - if row < filter_shapex or col < filter_shapey: + image = tf.convert_to_tensor(image, name="image") + + rank = image.shape.rank + if rank != 3 and rank != 4: + raise ValueError("image should be either 3 or 4-dimensional.") + + if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]: raise ValueError( - 'Number of Pixels in each dimension of the image should be \ - more than the filter size. Got filter_shape (%sx' % - filter_shape[0] + '%s).' % filter_shape[1] + - ' Image Shape (%s)' % image.shape) - if filter_shapex % 2 == 0 or filter_shapey % 2 == 0: - raise ValueError('Filter size should be odd. Got filter_shape (%sx' - % filter_shape[0] + '%s)' % filter_shape[1]) - image = tf.cast(image, tf.float32) - tf_i = tf.reshape(image, [row * col * ch]) - ma = tf.math.reduce_max(tf_i) - image = _normalize(image, ma) + "padding should be one of \"REFLECT\", \"CONSTANT\", or " + "\"SYMMETRIC\".") - # k and l is the Zero-padding size + filter_shape = keras_utils.conv_utils.normalize_tuple( + filter_shape, 2, "filter_shape") - listi = [] - for a in range(ch): - img = image[:, :, a:a + 1] - img = tf.reshape(img, [1, row, col, 1]) - slic = tf.image.extract_patches( - img, [1, filter_shapex, filter_shapey, 1], [1, 1, 1, 1], - [1, 1, 1, 1], - padding='SAME') - li = tf.reduce_mean(slic, axis=-1) - li = tf.reshape(li, [row, col, 1]) - listi.append(li) - y = tf.concat(listi[0], 2) + # Expand to a 4-D tensor + if rank == 3: + image = tf.expand_dims(image, axis=0) - for i in range(len(listi) - 1): - y = tf.concat([y, listi[i + 1]], 2) + # Keep the precision if it's float; + # otherwise, convert to float32 for computing. + if not image.dtype.is_floating: + image = tf.dtypes.cast(image, tf.dtypes.float32) - y *= 255 - y = tf.cast(y, tf.int32) + # Explicitly pad the image + image = _pad( + image, filter_shape, mode=padding, constant_values=constant_values) - return y + # Filter of shape (filter_width, filter_height, in_channels, 1) + # has value 1 / (filter_width * filter_height) for each element. + area = filter_shape[0] * filter_shape[1] + filter_shape = filter_shape + (tf.shape(image)[-1], 1) + kernel = tf.ones(shape=filter_shape, dtype=image.dtype) / area + + output = tf.nn.depthwise_conv2d( + image, kernel, strides=(1, 1, 1, 1), padding="VALID") + + # Squeeze out the first axis to make sure + # output has the same dimension with image. + if rank == 3: + output = tf.squeeze(output, axis=0) + + return output @tf.function diff --git a/tensorflow_addons/image/filters_test.py b/tensorflow_addons/image/filters_test.py index c4a3bb4bab..42e801a2e2 100644 --- a/tensorflow_addons/image/filters_test.py +++ b/tensorflow_addons/image/filters_test.py @@ -23,73 +23,129 @@ from tensorflow_addons.utils import test_utils +@test_utils.run_all_in_graph_and_eager_modes class MeanFilter2dTest(tf.test.TestCase): - def _validate_mean_filter2d(self, - inputs, - expected_values, - filter_shape=(3, 3)): - output = mean_filter2d(inputs, filter_shape) - self.assertAllClose(output, expected_values) - - @test_utils.run_in_graph_and_eager_modes - def test_filter_tuple(self): - tf_img = tf.zeros([3, 4, 3], tf.int32) - - for filter_shape in [3, 3.5, 'dt', None]: - with self.assertRaisesRegexp(TypeError, - 'Filter shape must be a tuple'): - mean_filter2d(tf_img, filter_shape) + def test_invalid_image(self): + msg = "image should be either 3 or 4-dimensional." - filter_shape = (3, 3, 3) - msg = ('Filter shape must be a tuple of 2 integers. ' - 'Got %s values in tuple' % len(filter_shape)) - with self.assertRaisesRegexp(ValueError, msg): - mean_filter2d(tf_img, filter_shape) + for image_shape in [(28, 28), (16, 28, 28, 1, 1)]: + with self.subTest(dim=len(image_shape)): + with self.assertRaisesRegexp(ValueError, msg): + mean_filter2d(tf.ones(shape=image_shape)) - msg = 'Size of the filter must be Integers' - for filter_shape in [(3.5, 3), (None, 3)]: - with self.assertRaisesRegexp(TypeError, msg): - mean_filter2d(tf_img, filter_shape) + def test_invalid_filter_shape(self): + msg = ("The `filter_shape` argument must be a tuple of " "2 integers.") + image = tf.ones(shape=(1, 28, 28, 1)) - @test_utils.run_in_graph_and_eager_modes - def test_filter_value(self): - tf_img = tf.zeros([3, 4, 3], tf.int32) + for filter_shape in [(3, 3, 3), (3, None, 3), None]: + with self.subTest(filter_shape=filter_shape): + with self.assertRaisesRegexp(ValueError, msg): + mean_filter2d(image, filter_shape=filter_shape) - with self.assertRaises(ValueError): - mean_filter2d(tf_img, (4, 3)) + def test_invalid_padding(self): + msg = ("padding should be one of \"REFLECT\", \"CONSTANT\", " + "or \"SYMMETRIC\".") + image = tf.ones(shape=(1, 28, 28, 1)) - @test_utils.run_deprecated_v1 - def test_dimension(self): - for image_shape in [(3, 4, None), (3, None, 4), (None, 3, 4)]: - with self.assertRaises(TypeError): - tf_img = tf.compat.v1.placeholder(tf.int32, shape=image_shape) - mean_filter2d(tf_img) - - @test_utils.run_in_graph_and_eager_modes - def test_image_vs_filter(self): - tf_img = tf.zeros([3, 4, 3], tf.int32) - filter_shape = (3, 5) - with self.assertRaises(ValueError): - mean_filter2d(tf_img, filter_shape) - - @test_utils.run_in_graph_and_eager_modes - def test_three_channels(self): - tf_img = [[[0.32801723, 0.08863795, 0.79119259], - [0.35526001, 0.79388736, 0.55435993], - [0.11607035, 0.55673079, 0.99473371]], - [[0.53240645, 0.74684819, 0.33700031], - [0.01760473, 0.28181609, 0.9751476], - [0.01605137, 0.8292904, 0.56405609]], - [[0.57215374, 0.10155051, 0.64836128], - [0.36533048, 0.91401874, 0.02524159], - [0.56379134, 0.9028874, 0.19505117]]] - - tf_img = tf.convert_to_tensor(value=tf_img) - expt = [[[34, 54, 75], [38, 93, 119], [14, 69, 87]], - [[61, 82, 94], [81, 147, 144], [40, 121, 93]], - [[42, 57, 56], [58, 106, 77], [27, 82, 49]]] - expt = tf.convert_to_tensor(value=expt) - self._validate_mean_filter2d(tf_img, expt) + with self.assertRaisesRegexp(ValueError, msg): + mean_filter2d(image, padding="TEST") + + def test_3d_image(self): + # Test shape (3, 3, 1) + image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + shape=(3, 3, 1)) + + # (3, 3) filter shape: + # reflected padding with 1 pixel to each direction + # 5 4 5 6 5 + # 2 1 2 3 2 + # 5 4 5 6 5 + # 8 7 8 9 8 + # 5 4 5 6 5 + + expected = tf.constant( + [[33. / 9., 36. / 9., 39. / 9.], [42. / 9., 45. / 9., 48. / 9.], + [51. / 9., 54. / 9., 57. / 9.]], + shape=(3, 3, 1)) + + output = mean_filter2d(image) + self.assertAllClose(output, expected) + + # Test shape (3, 3, 3) + image = tf.stack([image, 2. * image, 3. * image], axis=-1) + expected = tf.stack([expected, 2. * expected, 3. * expected], axis=-1) + # Squeeze shape from (3, 3, 1, 3) to (3, 3, 3) + image = tf.squeeze(image, axis=-2) + expected = tf.squeeze(expected, axis=-2) + + output = mean_filter2d(image) + self.assertAllClose(output, expected) + + def test_4d_image(self): + # Test shape (2, 3, 3, 1) + image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + shape=(3, 3, 1)) + + expected = tf.constant( + [[33. / 9., 36. / 9., 39. / 9.], [42. / 9., 45. / 9., 48. / 9.], + [51. / 9., 54. / 9., 57. / 9.]], + shape=(3, 3, 1)) + + # Batch size = 2, shape = (2, 3, 3, 1) + image = tf.stack([image, 2. * image], axis=0) + expected = tf.stack([expected, 2. * expected], axis=0) + + output = mean_filter2d(image) + self.assertAllClose(output, expected) + + # Test shape (2, 3, 3, 3) + image = tf.stack([image, image, image], axis=-1) + expected = tf.stack([expected, expected, expected], axis=-1) + image = tf.squeeze(image, axis=-2) + expected = tf.squeeze(expected, axis=-2) + + output = mean_filter2d(image) + self.assertAllClose(output, expected) + + def test_zero_padding(self): + image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + shape=(3, 3, 1)) + + # (3, 3) filter shape: + # zero padding with 1 pixel to each direction + # 0 0 0 0 0 + # 0 1 2 3 0 + # 0 4 5 6 0 + # 0 7 8 9 0 + # 0 0 0 0 0 + + expected = tf.constant( + [[12. / 9., 21. / 9., 16. / 9.], [27. / 9., 45. / 9., 33. / 9.], + [24. / 9., 39. / 9., 28. / 9.]], + shape=(3, 3, 1)) + + output = mean_filter2d(image, padding="CONSTANT", constant_values=0) + self.assertAllClose(output, expected) + + def test_symmetric_padding(self): + image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + shape=(3, 3, 1)) + + # (3, 3) filter shape: + # symmetric padding with 1 pixel to each direction + # 1 1 2 3 3 + # 1 1 2 3 3 + # 4 4 5 6 6 + # 7 7 8 9 9 + # 7 7 8 9 9 + + expected = tf.constant( + [[21. / 9., 27. / 9., 33. / 9.], [39. / 9., 45. / 9., 51. / 9.], + [57. / 9., 63. / 9., 69. / 9.]], + shape=(3, 3, 1)) + + output = mean_filter2d(image, padding="SYMMETRIC") + self.assertAllClose(output, expected) class MedianFilter2dTest(tf.test.TestCase): diff --git a/tensorflow_addons/utils/keras_utils.py b/tensorflow_addons/utils/keras_utils.py index a738e406ff..c3179a20dc 100644 --- a/tensorflow_addons/utils/keras_utils.py +++ b/tensorflow_addons/utils/keras_utils.py @@ -21,6 +21,7 @@ # TODO: find public API alternative to these from tensorflow.python.keras.losses import LossFunctionWrapper # pylint: disable=unused-import +from tensorflow.python.keras.utils import conv_utils # pylint: disable=unused-import def register_keras_custom_object(cls): From 51b18f626adce8724acfca713268bad8218c7fda Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Fri, 3 May 2019 11:31:23 +0800 Subject: [PATCH 2/9] cast the output back to the original dtype --- tensorflow_addons/image/filters.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index 173c26f607..d89d4d83c0 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -67,7 +67,7 @@ def mean_filter2d(image, padding mode. name: A name for this operation (optional). Returns: - 3-D or 4-D float `Tensor`, as per the input. + 3-D or 4-D `Tensor` of the same dtype as input. Raises: ValueError: If `image` is not 3 or 4-dimensional, if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC", @@ -94,6 +94,7 @@ def mean_filter2d(image, # Keep the precision if it's float; # otherwise, convert to float32 for computing. + orig_dtype = image.dtype if not image.dtype.is_floating: image = tf.dtypes.cast(image, tf.dtypes.float32) @@ -115,7 +116,7 @@ def mean_filter2d(image, if rank == 3: output = tf.squeeze(output, axis=0) - return output + return tf.dtypes.cast(output, orig_dtype) @tf.function From ea10aac870f82f01fedac9d0dfe848acb42b00dd Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Fri, 24 May 2019 15:52:23 +0800 Subject: [PATCH 3/9] refactor test cases --- tensorflow_addons/image/filters_test.py | 192 ++++++++++++------------ 1 file changed, 97 insertions(+), 95 deletions(-) diff --git a/tensorflow_addons/image/filters_test.py b/tensorflow_addons/image/filters_test.py index 42e801a2e2..365c666a3e 100644 --- a/tensorflow_addons/image/filters_test.py +++ b/tensorflow_addons/image/filters_test.py @@ -25,6 +25,50 @@ @test_utils.run_all_in_graph_and_eager_modes class MeanFilter2dTest(tf.test.TestCase): + def _tile_image(self, plane, image_shape): + assert 3 <= len(image_shape) <= 4 + plane = tf.convert_to_tensor(plane) + plane = tf.expand_dims(plane, -1) + channels = image_shape[-1] + image = tf.tile(plane, (1, 1, channels)) + + if len(image_shape) == 4: + batch_size = image_shape[0] + image = tf.expand_dims(image, 0) + image = tf.tile(image, (batch_size, 1, 1, 1)) + + return image + + def _setup_values(self, image_shape, filter_shape, padding, + constant_values, dtype): + assert 3 <= len(image_shape) <= 4 + height, width = image_shape[-3], image_shape[-2] + plane = tf.constant([x for x in range(1, height * width + 1)], + shape=(height, width), + dtype=dtype) + image = self._tile_image(plane, image_shape=image_shape) + + result = mean_filter2d( + image, + filter_shape=filter_shape, + padding=padding, + constant_values=constant_values) + + return result + + def _verify_values(self, image_shape, filter_shape, padding, + constant_values, expected_plane): + + expected_output = self._tile_image(expected_plane, image_shape) + dtypes = tf.dtypes + for dtype in [ + dtypes.uint8, dtypes.float16, dtypes.float32, dtypes.float64 + ]: + result = self._setup_values(image_shape, filter_shape, padding, + constant_values, dtype) + self.assertAllCloseAccordingToType( + result, tf.dtypes.cast(expected_output, dtype)) + def test_invalid_image(self): msg = "image should be either 3 or 4-dimensional." @@ -34,7 +78,7 @@ def test_invalid_image(self): mean_filter2d(tf.ones(shape=image_shape)) def test_invalid_filter_shape(self): - msg = ("The `filter_shape` argument must be a tuple of " "2 integers.") + msg = ("The `filter_shape` argument must be a tuple of 2 integers.") image = tf.ones(shape=(1, 28, 28, 1)) for filter_shape in [(3, 3, 3), (3, None, 3), None]: @@ -50,102 +94,60 @@ def test_invalid_padding(self): with self.assertRaisesRegexp(ValueError, msg): mean_filter2d(image, padding="TEST") - def test_3d_image(self): - # Test shape (3, 3, 1) - image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - shape=(3, 3, 1)) - - # (3, 3) filter shape: - # reflected padding with 1 pixel to each direction - # 5 4 5 6 5 - # 2 1 2 3 2 - # 5 4 5 6 5 - # 8 7 8 9 8 - # 5 4 5 6 5 - - expected = tf.constant( - [[33. / 9., 36. / 9., 39. / 9.], [42. / 9., 45. / 9., 48. / 9.], - [51. / 9., 54. / 9., 57. / 9.]], - shape=(3, 3, 1)) - - output = mean_filter2d(image) - self.assertAllClose(output, expected) - - # Test shape (3, 3, 3) - image = tf.stack([image, 2. * image, 3. * image], axis=-1) - expected = tf.stack([expected, 2. * expected, 3. * expected], axis=-1) - # Squeeze shape from (3, 3, 1, 3) to (3, 3, 3) - image = tf.squeeze(image, axis=-2) - expected = tf.squeeze(expected, axis=-2) - - output = mean_filter2d(image) - self.assertAllClose(output, expected) - - def test_4d_image(self): - # Test shape (2, 3, 3, 1) - image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - shape=(3, 3, 1)) - - expected = tf.constant( - [[33. / 9., 36. / 9., 39. / 9.], [42. / 9., 45. / 9., 48. / 9.], - [51. / 9., 54. / 9., 57. / 9.]], - shape=(3, 3, 1)) - - # Batch size = 2, shape = (2, 3, 3, 1) - image = tf.stack([image, 2. * image], axis=0) - expected = tf.stack([expected, 2. * expected], axis=0) - - output = mean_filter2d(image) - self.assertAllClose(output, expected) - - # Test shape (2, 3, 3, 3) - image = tf.stack([image, image, image], axis=-1) - expected = tf.stack([expected, expected, expected], axis=-1) - image = tf.squeeze(image, axis=-2) - expected = tf.squeeze(expected, axis=-2) - - output = mean_filter2d(image) - self.assertAllClose(output, expected) - - def test_zero_padding(self): - image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - shape=(3, 3, 1)) - - # (3, 3) filter shape: - # zero padding with 1 pixel to each direction - # 0 0 0 0 0 - # 0 1 2 3 0 - # 0 4 5 6 0 - # 0 7 8 9 0 - # 0 0 0 0 0 - - expected = tf.constant( - [[12. / 9., 21. / 9., 16. / 9.], [27. / 9., 45. / 9., 33. / 9.], - [24. / 9., 39. / 9., 28. / 9.]], - shape=(3, 3, 1)) - - output = mean_filter2d(image, padding="CONSTANT", constant_values=0) - self.assertAllClose(output, expected) + def test_reflect_padding(self): + expected_plane = tf.constant([[33. / 9., 36. / 9., 39. / 9.], + [42. / 9., 45. / 9., 48. / 9.], + [51. / 9., 54. / 9., 57. / 9.]]) + + for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), + (2, 3, 3, 1), (2, 3, 3, 3)]: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="REFLECT", + constant_values=0, + expected_plane=expected_plane) + + def test_constant_padding(self): + expected_plane = tf.constant([[12. / 9., 21. / 9., 16. / 9.], + [27. / 9., 45. / 9., 33. / 9.], + [24. / 9., 39. / 9., 28. / 9.]]) + + for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), + (2, 3, 3, 1), (2, 3, 3, 3)]: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="CONSTANT", + constant_values=0, + expected_plane=expected_plane) + + expected_plane = tf.constant([[17. / 9., 24. / 9., 21. / 9.], + [30. / 9., 45. / 9., 36. / 9.], + [29. / 9., 42. / 9., 33. / 9.]]) + + for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), + (2, 3, 3, 1), (2, 3, 3, 3)]: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="CONSTANT", + constant_values=1, + expected_plane=expected_plane) def test_symmetric_padding(self): - image = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], - shape=(3, 3, 1)) - - # (3, 3) filter shape: - # symmetric padding with 1 pixel to each direction - # 1 1 2 3 3 - # 1 1 2 3 3 - # 4 4 5 6 6 - # 7 7 8 9 9 - # 7 7 8 9 9 - - expected = tf.constant( - [[21. / 9., 27. / 9., 33. / 9.], [39. / 9., 45. / 9., 51. / 9.], - [57. / 9., 63. / 9., 69. / 9.]], - shape=(3, 3, 1)) - - output = mean_filter2d(image, padding="SYMMETRIC") - self.assertAllClose(output, expected) + expected_plane = tf.constant([[21. / 9., 27. / 9., 33. / 9.], + [39. / 9., 45. / 9., 51. / 9.], + [57. / 9., 63. / 9., 69. / 9.]]) + + for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), + (2, 3, 3, 1), (2, 3, 3, 3)]: + self._verify_values( + image_shape=image_shape, + filter_shape=(3, 3), + padding="SYMMETRIC", + constant_values=0, + expected_plane=expected_plane) class MedianFilter2dTest(tf.test.TestCase): From 498d3656922fd574bf4db148e664d03e328f5324 Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Fri, 24 May 2019 15:52:50 +0800 Subject: [PATCH 4/9] avoid loss of precision --- tensorflow_addons/image/filters.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index d89d4d83c0..383292e6c8 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -103,14 +103,17 @@ def mean_filter2d(image, image, filter_shape, mode=padding, constant_values=constant_values) # Filter of shape (filter_width, filter_height, in_channels, 1) - # has value 1 / (filter_width * filter_height) for each element. - area = filter_shape[0] * filter_shape[1] + # has the value of 1 for each element. + area = tf.constant( + filter_shape[0] * filter_shape[1], dtype=image.dtype) filter_shape = filter_shape + (tf.shape(image)[-1], 1) - kernel = tf.ones(shape=filter_shape, dtype=image.dtype) / area + kernel = tf.ones(shape=filter_shape, dtype=image.dtype) output = tf.nn.depthwise_conv2d( image, kernel, strides=(1, 1, 1, 1), padding="VALID") + output /= area + # Squeeze out the first axis to make sure # output has the same dimension with image. if rank == 3: From 39ae5a3dd5c49c6f25935c37ff7dbdff3642dbe1 Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Fri, 24 May 2019 16:05:28 +0800 Subject: [PATCH 5/9] add test case with channels of None --- tensorflow_addons/image/filters_test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensorflow_addons/image/filters_test.py b/tensorflow_addons/image/filters_test.py index 365c666a3e..1702033640 100644 --- a/tensorflow_addons/image/filters_test.py +++ b/tensorflow_addons/image/filters_test.py @@ -94,6 +94,19 @@ def test_invalid_padding(self): with self.assertRaisesRegexp(ValueError, msg): mean_filter2d(image, padding="TEST") + def test_none_channels(self): + # 3-D image + fn = mean_filter2d.get_concrete_function( + tf.TensorSpec(dtype=tf.dtypes.float32, shape=(3, 3, None))) + fn(tf.random.uniform(shape=(3, 3, 1))) + fn(tf.random.uniform(shape=(3, 3, 3))) + + # 4-D image + fn = mean_filter2d.get_concrete_function( + tf.TensorSpec(dtype=tf.dtypes.float32, shape=(1, 3, 3, None))) + fn(tf.random.uniform(shape=(1, 3, 3, 1))) + fn(tf.random.uniform(shape=(1, 3, 3, 3))) + def test_reflect_padding(self): expected_plane = tf.constant([[33. / 9., 36. / 9., 39. / 9.], [42. / 9., 45. / 9., 48. / 9.], From 2fddf16ba9521529978b7723eafb78164ff4a70b Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Fri, 24 May 2019 16:19:10 +0800 Subject: [PATCH 6/9] add doc of _tile_image --- tensorflow_addons/image/filters_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow_addons/image/filters_test.py b/tensorflow_addons/image/filters_test.py index 1702033640..3aa4f85a84 100644 --- a/tensorflow_addons/image/filters_test.py +++ b/tensorflow_addons/image/filters_test.py @@ -26,6 +26,7 @@ @test_utils.run_all_in_graph_and_eager_modes class MeanFilter2dTest(tf.test.TestCase): def _tile_image(self, plane, image_shape): + "Tile a 2-D image `plane` into 3-D or 4-D as per `image_shape`." assert 3 <= len(image_shape) <= 4 plane = tf.convert_to_tensor(plane) plane = tf.expand_dims(plane, -1) From 8e690834fe1bfa610a67d4824dc8398823dcea0d Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Mon, 27 May 2019 20:33:31 +0800 Subject: [PATCH 7/9] use ones instead of random data --- tensorflow_addons/image/filters_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/image/filters_test.py b/tensorflow_addons/image/filters_test.py index 3aa4f85a84..d4a2e2ed09 100644 --- a/tensorflow_addons/image/filters_test.py +++ b/tensorflow_addons/image/filters_test.py @@ -99,14 +99,14 @@ def test_none_channels(self): # 3-D image fn = mean_filter2d.get_concrete_function( tf.TensorSpec(dtype=tf.dtypes.float32, shape=(3, 3, None))) - fn(tf.random.uniform(shape=(3, 3, 1))) - fn(tf.random.uniform(shape=(3, 3, 3))) + fn(tf.ones(shape=(3, 3, 1))) + fn(tf.ones(shape=(3, 3, 3))) # 4-D image fn = mean_filter2d.get_concrete_function( tf.TensorSpec(dtype=tf.dtypes.float32, shape=(1, 3, 3, None))) - fn(tf.random.uniform(shape=(1, 3, 3, 1))) - fn(tf.random.uniform(shape=(1, 3, 3, 3))) + fn(tf.ones(shape=(1, 3, 3, 1))) + fn(tf.ones(shape=(1, 3, 3, 3))) def test_reflect_padding(self): expected_plane = tf.constant([[33. / 9., 36. / 9., 39. / 9.], From 94389b8680ae35516ed40e102d31af1f3f0f9394 Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Mon, 27 May 2019 21:34:18 +0800 Subject: [PATCH 8/9] add test case with 4x4 filter --- tensorflow_addons/image/filters_test.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/image/filters_test.py b/tensorflow_addons/image/filters_test.py index d4a2e2ed09..8e8e29011e 100644 --- a/tensorflow_addons/image/filters_test.py +++ b/tensorflow_addons/image/filters_test.py @@ -26,7 +26,7 @@ @test_utils.run_all_in_graph_and_eager_modes class MeanFilter2dTest(tf.test.TestCase): def _tile_image(self, plane, image_shape): - "Tile a 2-D image `plane` into 3-D or 4-D as per `image_shape`." + """Tile a 2-D image `plane` into 3-D or 4-D as per `image_shape`.""" assert 3 <= len(image_shape) <= 4 plane = tf.convert_to_tensor(plane) plane = tf.expand_dims(plane, -1) @@ -108,7 +108,7 @@ def test_none_channels(self): fn(tf.ones(shape=(1, 3, 3, 1))) fn(tf.ones(shape=(1, 3, 3, 3))) - def test_reflect_padding(self): + def test_reflect_padding_with_3x3_filter(self): expected_plane = tf.constant([[33. / 9., 36. / 9., 39. / 9.], [42. / 9., 45. / 9., 48. / 9.], [51. / 9., 54. / 9., 57. / 9.]]) @@ -122,7 +122,21 @@ def test_reflect_padding(self): constant_values=0, expected_plane=expected_plane) - def test_constant_padding(self): + def test_reflect_padding_with_4x4_filter(self): + expected_plane = tf.constant([[80. / 16., 80. / 16., 80. / 16.], + [80. / 16., 80. / 16., 80. / 16.], + [80. / 16., 80. / 16., 80. / 16.]]) + + for image_shape in [(3, 3, 1), (3, 3, 3), (1, 3, 3, 1), (1, 3, 3, 3), + (2, 3, 3, 1), (2, 3, 3, 3)]: + self._verify_values( + image_shape=image_shape, + filter_shape=(4, 4), + padding="REFLECT", + constant_values=0, + expected_plane=expected_plane) + + def test_constant_padding_with_3x3_filter(self): expected_plane = tf.constant([[12. / 9., 21. / 9., 16. / 9.], [27. / 9., 45. / 9., 33. / 9.], [24. / 9., 39. / 9., 28. / 9.]]) @@ -149,7 +163,7 @@ def test_constant_padding(self): constant_values=1, expected_plane=expected_plane) - def test_symmetric_padding(self): + def test_symmetric_padding_with_3x3_filter(self): expected_plane = tf.constant([[21. / 9., 27. / 9., 33. / 9.], [39. / 9., 45. / 9., 51. / 9.], [57. / 9., 63. / 9., 69. / 9.]]) From e1aca51a7dc469f38220e2fcf6a32d434a62e829 Mon Sep 17 00:00:00 2001 From: WindQAQ Date: Tue, 28 May 2019 19:15:41 +0800 Subject: [PATCH 9/9] add doc related to padding --- tensorflow_addons/image/filters.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index 383292e6c8..b567d538b6 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -36,7 +36,24 @@ def func2(): def _pad(image, filter_shape, mode="CONSTANT", constant_values=0): - """Explicitly pad a 4-D image.""" + """Explicitly pad a 4-D image. + + Equivalent to the implicit padding method offered in `tf.nn.conv2d` and + `tf.nn.depthwise_conv2d`, but supports non-zero, reflect and symmetric + padding mode. For the even-sized filter, it pads one more value to the + right or the bottom side. + + Args: + image: A 4-D `Tensor` of shape `[batch_size, height, width, channels]`. + filter_shape: A `tuple`/`list` of 2 integers, specifying the height + and width of the 2-D filter. + mode: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC". + The type of padding algorithm to use, which is compatible with + `mode` argument in `tf.pad`. For more details, please refer to + https://www.tensorflow.org/api_docs/python/tf/pad. + constant_values: A `scalar`, the pad value to use in "CONSTANT" + padding mode. + """ assert mode in ["CONSTANT", "REFLECT", "SYMMETRIC"] filter_height, filter_width = filter_shape pad_top = (filter_height - 1) // 2 @@ -62,7 +79,9 @@ def mean_filter2d(image, the height and width of the 2-D mean filter. Can be a single integer to specify the same value for all spatial dimensions. padding: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC". - The type of padding algorithm to use. + The type of padding algorithm to use, which is compatible with + `mode` argument in `tf.pad`. For more details, please refer to + https://www.tensorflow.org/api_docs/python/tf/pad. constant_values: A `scalar`, the pad value to use in "CONSTANT" padding mode. name: A name for this operation (optional).